multi agent debate -> multi agent collab + reset method in flow + updated tests for multiagentcollab, code interpreter fix in flow, loggers, eleven lab, tests, and docs

pull/170/head
Kye 1 year ago
parent 954e580af0
commit fddc12f828

@ -0,0 +1,82 @@
# ElevenLabsText2SpeechTool Documentation
## Table of Contents
1. [Introduction](#introduction)
2. [Class Overview](#class-overview)
- [Attributes](#attributes)
3. [Installation](#installation)
4. [Usage](#usage)
- [Initialization](#initialization)
- [Converting Text to Speech](#converting-text-to-speech)
- [Playing and Streaming Speech](#playing-and-streaming-speech)
5. [Exception Handling](#exception-handling)
6. [Advanced Usage](#advanced-usage)
7. [Contributing](#contributing)
8. [References](#references)
## 1. Introduction <a name="introduction"></a>
The `ElevenLabsText2SpeechTool` is a Python class designed to simplify the process of converting text to speech using the Eleven Labs Text2Speech API. This tool is a wrapper around the API and provides a convenient interface for generating speech from text. It supports multiple languages, making it suitable for a wide range of applications, including voice assistants, audio content generation, and more.
## 2. Class Overview <a name="class-overview"></a>
### Attributes <a name="attributes"></a>
- `model` (Union[ElevenLabsModel, str]): The model to use for text to speech. Defaults to `ElevenLabsModel.MULTI_LINGUAL`.
- `name` (str): The name of the tool. Defaults to `"eleven_labs_text2speech"`.
- `description` (str): A brief description of the tool. Defaults to a detailed explanation of its functionality.
## 3. Installation <a name="installation"></a>
To use the `ElevenLabsText2SpeechTool`, you need to install the required dependencies and have access to the Eleven Labs Text2Speech API. Follow these steps:
1. Install the `elevenlabs` library:
```bash
pip install elevenlabs
```
2. Install the `swarms` library: `pip install swarms`
3. Set up your API key by following the instructions at [Eleven Labs Documentation](https://docs.elevenlabs.io/welcome/introduction).
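For example, the key can be made available to the tool through the `ELEVEN_API_KEY` environment variable, which the tool checks at initialization (the key value below is a placeholder):
```python
import os

# Assumed placeholder value; substitute your real Eleven Labs API key.
os.environ["ELEVEN_API_KEY"] = "your-api-key"
```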
## 4. Usage <a name="usage"></a>
### Initialization <a name="initialization"></a>
To get started, create an instance of the `ElevenLabsText2SpeechTool`. You can customize the `model` attribute if needed.
```python
from swarms.models import ElevenLabsModel, ElevenLabsText2SpeechTool

tts = ElevenLabsText2SpeechTool(model=ElevenLabsModel.MONO_LINGUAL)
```
### Converting Text to Speech <a name="converting-text-to-speech"></a>
You can use the `run` method to convert text to speech. It returns the path to the generated speech file.
```python
speech_file = tts.run("Hello, this is a test.")
```
### Playing and Streaming Speech <a name="playing-and-streaming-speech"></a>
- Use the `play` method to play the generated speech file.
```python
tts.play(speech_file)
```
- Use the `stream_speech` method to stream the text as speech. It plays the speech in real-time.
```python
tts.stream_speech("Hello world!")
```
## 5. Exception Handling <a name="exception-handling"></a>
The `ElevenLabsText2SpeechTool` handles exceptions gracefully. If an error occurs during the conversion process, it raises a `RuntimeError` with an informative error message.
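A minimal sketch of catching that error, assuming the tool was initialized as `tts` above:
```python
try:
    speech_file = tts.run("Hello, this is a test.")
except RuntimeError as e:
    # The tool wraps conversion failures in a RuntimeError with details.
    print(f"Text-to-speech conversion failed: {e}")
```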
## 6. Advanced Usage <a name="advanced-usage"></a>
- You can implement custom error handling and logging to further enhance the functionality of this tool.
- For advanced users, extending the class to support additional features or customization is possible; see the sketch below.
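As an illustration, here is a minimal sketch of such an extension that adds logging around each conversion. The subclass name and logging behavior are illustrative, not part of the library:
```python
import logging

from swarms.models import ElevenLabsText2SpeechTool

logger = logging.getLogger(__name__)


class LoggingText2SpeechTool(ElevenLabsText2SpeechTool):
    """Hypothetical subclass that logs each conversion and its output path."""

    def _run(self, task: str) -> str:
        # Log the request, delegate to the parent implementation, then log the result.
        logger.info("Converting %d characters of text to speech", len(task))
        speech_file = super()._run(task)
        logger.info("Speech written to %s", speech_file)
        return speech_file
```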
## 7. Contributing <a name="contributing"></a>
Contributions to this tool are welcome. Feel free to open issues, submit pull requests, or provide feedback to improve its functionality and documentation.
## 8. References <a name="references"></a>
- [Eleven Labs Text2Speech API Documentation](https://docs.elevenlabs.io/welcome/introduction)
This documentation provides a comprehensive guide to using the `ElevenLabsText2SpeechTool`. It covers installation, basic usage, advanced features, and contribution guidelines. Refer to the [References](#references) section for additional resources.

@ -33,26 +33,28 @@ doc_analyzer_agent = Flow(
sop=DOC_ANALYZER_AGENT_PROMPT,
max_loops=1,
autosave=True,
saved_state_path="doc_analyzer_agent.json"
saved_state_path="doc_analyzer_agent.json",
)
summary_generator_agent = Flow(
llm=llm2,
sop=SUMMARY_GENERATOR_AGENT_PROMPT,
max_loops=1,
autosave=True,
saved_state_path="summary_generator_agent.json"
saved_state_path="summary_generator_agent.json",
)
decision_making_support_agent = Flow(
llm=llm2,
sop=DECISION_MAKING_PROMPT,
max_loops=1,
saved_state_path="decision_making_support_agent.json"
saved_state_path="decision_making_support_agent.json",
)
pdf_path = "bankstatement.pdf"
fraud_detection_instructions = "Detect fraud in the document"
summary_agent_instructions = "Generate an actionable summary of the document with action steps to take"
summary_agent_instructions = (
"Generate an actionable summary of the document with action steps to take"
)
decision_making_support_agent_instructions = (
"Provide decision making support to the business owner:"
)

@ -16,7 +16,7 @@ load_dotenv()
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")
PDF_PATH = "videocon.pdf"
PDF_PATH = "fasterffn.pdf"
# Base llms
@ -43,8 +43,11 @@ paper_implementor_agent = Flow(
max_loops=1,
autosave=True,
saved_state_path="paper_implementor.json",
code_interpreter=False,
)
paper = pdf_to_text(PDF_PATH)
algorithmic_psuedocode_agent = paper_summarizer_agent.run(paper)
algorithmic_psuedocode_agent = paper_summarizer_agent.run(
f"Focus on creating the algorithmic pseudocode for the novel method in this paper: {paper}"
)
pytorch_code = paper_implementor_agent.run(algorithmic_psuedocode_agent)

@ -1,517 +0,0 @@
import json
import os
import shutil
import time
import xml.etree.ElementTree as ET
import zipfile
from tempfile import mkdtemp
from typing import Dict, Optional
from urllib.parse import urlparse
import pyautogui
import requests
import semver
import undetected_chromedriver as uc # type: ignore
import yaml
from extension import load_extension
from pydantic import BaseModel
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from tqdm import tqdm
def _is_blank_agent(agent_name: str) -> bool:
with open(f"agents/{agent_name}.py", "r") as agent_file:
agent_data = agent_file.read()
with open("src/template.py", "r") as template_file:
template_data = template_file.read()
return agent_data == template_data
def record(agent_name: str, autotab_ext_path: Optional[str] = None):
if not os.path.exists("agents"):
os.makedirs("agents")
if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local":
if not _is_blank_agent(agent_name=agent_name):
raise Exception(f"Agent with name {agent_name} already exists")
driver = get_driver( # noqa: F841
autotab_ext_path=autotab_ext_path,
record_mode=True,
)
# Need to keep a reference to the driver so that it doesn't get garbage collected
with open("src/template.py", "r") as file:
data = file.read()
with open(f"agents/{agent_name}.py", "w") as file:
file.write(data)
print(
"\033[34mYou have the Python debugger open, you can run commands in it like you"
" would in a normal Python shell.\033[0m"
)
print(
"\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and"
" press enter.\033[0m"
)
breakpoint()
if __name__ == "__main__":
record("agent")
def extract_domain_from_url(url: str):
# url = http://username:password@hostname:port/path?arg=value#anchor
parsed_url = urlparse(url)
hostname = parsed_url.hostname
if hostname is None:
raise ValueError(f"Could not extract hostname from url {url}")
if hostname.startswith("www."):
hostname = hostname[4:]
return hostname
class AutotabChromeDriver(uc.Chrome):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def find_element_with_retry(
self, by=By.ID, value: Optional[str] = None
) -> WebElement:
try:
return super().find_element(by, value)
except Exception as e:
# TODO: Use an LLM to retry, finding a similar element on the DOM
breakpoint()
raise e
def open_plugin(driver: AutotabChromeDriver):
print("Opening plugin sidepanel")
driver.execute_script("document.activeElement.blur();")
pyautogui.press("esc")
pyautogui.hotkey("command", "shift", "y", interval=0.05) # mypy: ignore
def open_plugin_and_login(driver: AutotabChromeDriver):
if config.autotab_api_key is not None:
backend_url = (
"http://localhost:8000"
if config.environment == "local"
else "https://api.autotab.com"
)
driver.get(f"{backend_url}/auth/signin-api-key-page")
response = requests.post(
f"{backend_url}/auth/signin-api-key",
json={"api_key": config.autotab_api_key},
)
cookie = response.json()
if response.status_code != 200:
if response.status_code == 401:
raise Exception("Invalid API key")
else:
raise Exception(
f"Error {response.status_code} from backend while logging you in"
f" with your API key: {response.text}"
)
cookie["name"] = cookie["key"]
del cookie["key"]
driver.add_cookie(cookie)
driver.get("https://www.google.com")
open_plugin(driver)
else:
print("No autotab API key found, heading to autotab.com to sign up")
url = (
"http://localhost:3000/dashboard"
if config.environment == "local"
else "https://autotab.com/dashboard"
)
driver.get(url)
time.sleep(0.5)
open_plugin(driver)
def get_driver(
autotab_ext_path: Optional[str] = None, record_mode: bool = False
) -> AutotabChromeDriver:
options = webdriver.ChromeOptions()
options.add_argument("--no-sandbox") # Necessary for running
options.add_argument(
"--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
" (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
)
options.add_argument("--enable-webgl")
options.add_argument("--enable-3d-apis")
options.add_argument("--enable-clipboard-read-write")
options.add_argument("--disable-popup-blocking")
if autotab_ext_path is None:
load_extension()
options.add_argument("--load-extension=./src/extension/autotab")
else:
options.add_argument(f"--load-extension={autotab_ext_path}")
options.add_argument("--allow-running-insecure-content")
options.add_argument("--disable-web-security")
options.add_argument(f"--user-data-dir={mkdtemp()}")
options.binary_location = config.chrome_binary_location
driver = AutotabChromeDriver(options=options)
if record_mode:
open_plugin_and_login(driver)
return driver
class SiteCredentials(BaseModel):
name: Optional[str] = None
email: Optional[str] = None
password: Optional[str] = None
login_with_google_account: Optional[str] = None
login_url: Optional[str] = None
def __init__(self, **data) -> None:
super().__init__(**data)
if self.name is None:
self.name = self.email
class GoogleCredentials(BaseModel):
credentials: Dict[str, SiteCredentials]
def __init__(self, **data) -> None:
super().__init__(**data)
for cred in self.credentials.values():
cred.login_url = "https://accounts.google.com/v3/signin"
@property
def default(self) -> SiteCredentials:
if "default" not in self.credentials:
if len(self.credentials) == 1:
return list(self.credentials.values())[0]
raise Exception("No default credentials found in config")
return self.credentials["default"]
class Config(BaseModel):
autotab_api_key: Optional[str]
credentials: Dict[str, SiteCredentials]
google_credentials: GoogleCredentials
chrome_binary_location: str
environment: str
@classmethod
def load_from_yaml(cls, path: str):
with open(path, "r") as config_file:
config = yaml.safe_load(config_file)
_credentials = {}
for domain, creds in config.get("credentials", {}).items():
if "login_url" not in creds:
creds["login_url"] = f"https://{domain}/login"
site_creds = SiteCredentials(**creds)
_credentials[domain] = site_creds
for alt in creds.get("alts", []):
_credentials[alt] = site_creds
google_credentials = {}
for creds in config.get("google_credentials", []):
credentials: SiteCredentials = SiteCredentials(**creds)
google_credentials[credentials.name] = credentials
chrome_binary_location = config.get("chrome_binary_location")
if chrome_binary_location is None:
raise Exception("Must specify chrome_binary_location in config")
autotab_api_key = config.get("autotab_api_key")
if autotab_api_key == "...":
autotab_api_key = None
return cls(
autotab_api_key=autotab_api_key,
credentials=_credentials,
google_credentials=GoogleCredentials(credentials=google_credentials),
chrome_binary_location=config.get("chrome_binary_location"),
environment=config.get("environment", "prod"),
)
def get_site_credentials(self, domain: str) -> SiteCredentials:
credentials = self.credentials[domain].copy()
return credentials
config = Config.load_from_yaml(".autotab.yaml")
def is_signed_in_to_google(driver):
cookies = driver.get_cookies()
return len([c for c in cookies if c["name"] == "SAPISID"]) != 0
def google_login(
driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True
):
print("Logging in to Google")
if navigate:
driver.get("https://accounts.google.com/")
time.sleep(1)
if is_signed_in_to_google(driver):
print("Already signed in to Google")
return
if os.path.exists("google_cookies.json"):
print("cookies exist, doing loading")
with open("google_cookies.json", "r") as f:
google_cookies = json.load(f)
for cookie in google_cookies:
if "expiry" in cookie:
cookie["expires"] = cookie["expiry"]
del cookie["expiry"]
driver.execute_cdp_cmd("Network.setCookie", cookie)
time.sleep(1)
driver.refresh()
time.sleep(2)
if not credentials:
credentials = config.google_credentials.default
if credentials is None:
raise Exception("No credentials provided for Google login")
email_input = driver.find_element(By.CSS_SELECTOR, "[type='email']")
email_input.send_keys(credentials.email)
email_input.send_keys(Keys.ENTER)
WebDriverWait(driver, 10).until(
EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))
)
password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
password_input.send_keys(credentials.password)
password_input.send_keys(Keys.ENTER)
time.sleep(1.5)
print("Successfully logged in to Google")
cookies = driver.get_cookies()
if not is_signed_in_to_google(driver):
# Probably wanted to have us solve a captcha, or 2FA or confirm recovery details
print("Need 2FA help to log in to Google")
# TODO: Show the screenshot to the user
breakpoint()
if not os.path.exists("google_cookies.json"):
print("Setting Google cookies for future use")
# Log out to have access to the right cookies
driver.get("https://accounts.google.com/Logout")
time.sleep(2)
cookies = driver.get_cookies()
cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"]
google_cookies = [
cookie
for cookie in cookies
if cookie["domain"] in [".google.com", "accounts.google.com"]
and cookie["name"] in cookie_names
]
with open("google_cookies.json", "w") as f:
json.dump(google_cookies, f)
# Log back in
login_button = driver.find_element(
By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']"
)
login_button.click()
time.sleep(1)
password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
password_input.send_keys(credentials.password)
password_input.send_keys(Keys.ENTER)
time.sleep(3)
print("Successfully copied Google cookies for the future")
def login(driver, url: str):
domain = extract_domain_from_url(url)
credentials = config.get_site_credentials(domain)
login_url = credentials.login_url
if credentials.login_with_google_account:
google_credentials = config.google_credentials.credentials[
credentials.login_with_google_account
]
_login_with_google(driver, login_url, google_credentials)
else:
_login(driver, login_url, credentials=credentials)
def _login(driver, url: str, credentials: SiteCredentials):
print(f"Logging in to {url}")
driver.get(url)
time.sleep(2)
email_input = driver.find_element(By.NAME, "email")
email_input.send_keys(credentials.email)
password_input = driver.find_element(By.NAME, "password")
password_input.send_keys(credentials.password)
password_input.send_keys(Keys.ENTER)
time.sleep(3)
print(f"Successfully logged in to {url}")
def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
print(f"Logging in to {url} with Google")
google_login(driver, credentials=google_credentials)
driver.get(url)
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
main_window = driver.current_window_handle
xpath = (
"//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with"
" Google') or contains(@title, 'Sign in with Google')]"
)
WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath)))
driver.find_element(
By.XPATH,
xpath,
).click()
driver.switch_to.window(driver.window_handles[-1])
driver.find_element(
By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]"
).click()
driver.switch_to.window(main_window)
time.sleep(5)
print(f"Successfully logged in to {url}")
def update():
print("updating extension...")
# Download the autotab.crx file
response = requests.get(
"https://github.com/Planetary-Computers/autotab-extension/raw/main/autotab.crx",
stream=True,
)
# Check if the directory exists, if not create it
if os.path.exists("src/extension/.autotab"):
shutil.rmtree("src/extension/.autotab")
os.makedirs("src/extension/.autotab")
# Open the file in write binary mode
total_size = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
t = tqdm(total=total_size, unit="iB", unit_scale=True)
with open("src/extension/.autotab/autotab.crx", "wb") as f:
for data in response.iter_content(block_size):
t.update(len(data))
f.write(data)
t.close()
if total_size != 0 and t.n != total_size:
print("ERROR, something went wrong")
# Unzip the file
with zipfile.ZipFile("src/extension/.autotab/autotab.crx", "r") as zip_ref:
zip_ref.extractall("src/extension/.autotab")
os.remove("src/extension/.autotab/autotab.crx")
if os.path.exists("src/extension/autotab"):
shutil.rmtree("src/extension/autotab")
os.rename("src/extension/.autotab", "src/extension/autotab")
def should_update():
if not os.path.exists("src/extension/autotab"):
return True
# Fetch the XML file
response = requests.get(
"https://raw.githubusercontent.com/Planetary-Computers/autotab-extension/main/update.xml"
)
xml_content = response.content
# Parse the XML file
root = ET.fromstring(xml_content)
namespaces = {"ns": "http://www.google.com/update2/response"} # add namespaces
xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version")
# Load the local JSON file
with open("src/extension/autotab/manifest.json", "r") as f:
json_content = json.load(f)
json_version = json_content["version"]
# Compare versions
return semver.compare(xml_version, json_version) > 0
def load_extension():
should_update() and update()
if __name__ == "__main__":
print("should update:", should_update())
update()
def play(agent_name: Optional[str] = None):
if agent_name is None:
agent_files = os.listdir("agents")
if len(agent_files) == 0:
raise Exception("No agents found in agents/ directory")
elif len(agent_files) == 1:
agent_file = agent_files[0]
else:
print("Found multiple agent files, please select one:")
for i, file in enumerate(agent_files, start=1):
print(f"{i}. {file}")
selected = int(input("Select a file by number: ")) - 1
agent_file = agent_files[selected]
else:
agent_file = f"{agent_name}.py"
os.system(f"python agents/{agent_file}")
if __name__ == "__main__":
play()
"""
chrome_binary_location: /Applications/Google Chrome.app/Contents/MacOS/Google Chrome
autotab_api_key: ... # Go to https://autotab.com/dashboard to get your API key, or
# run `autotab record` with this field blank and you will be prompted to log in to autotab
# Optional, programmatically login to services using "Login with Google" authentication
google_credentials:
- name: default
email: ...
password: ...
# Optional, specify alternative accounts to use with Google login on a per-service basis
- email: you@gmail.com # Credentials without a name use email as key
password: ...
credentials:
notion.so:
alts:
- notion.com
login_with_google_account: default
figma.com:
email: ...
password: ...
airtable.com:
login_with_google_account: you@gmail.com
"""

@ -3,4 +3,4 @@ QDRANT MEMORY CLASS
"""
"""

@ -0,0 +1,104 @@
import tempfile
from enum import Enum
from typing import Any, Dict, Union
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator
from swarms.tools.tool import BaseTool
def _import_elevenlabs() -> Any:
try:
import elevenlabs
except ImportError as e:
raise ImportError(
"Cannot import elevenlabs, please install `pip install elevenlabs`."
) from e
return elevenlabs
class ElevenLabsModel(str, Enum):
"""Models available for Eleven Labs Text2Speech."""
MULTI_LINGUAL = "eleven_multilingual_v1"
MONO_LINGUAL = "eleven_monolingual_v1"
class ElevenLabsText2SpeechTool(BaseTool):
"""Tool that queries the Eleven Labs Text2Speech API.
In order to set this up, follow instructions at:
https://docs.elevenlabs.io/welcome/introduction
Attributes:
model (ElevenLabsModel): The model to use for text to speech.
Defaults to ElevenLabsModel.MULTI_LINGUAL.
name (str): The name of the tool. Defaults to "eleven_labs_text2speech".
description (str): The description of the tool.
Defaults to "A wrapper around Eleven Labs Text2Speech. Useful for when you need to convert text to speech. It supports multiple languages, including English, German, Polish, Spanish, Italian, French, Portuguese, and Hindi."
Usage:
>>> from swarms.models import ElevenLabsText2SpeechTool
>>> stt = ElevenLabsText2SpeechTool()
>>> speech_file = stt.run("Hello world!")
>>> stt.play(speech_file)
>>> stt.stream_speech("Hello world!")
"""
model: Union[ElevenLabsModel, str] = ElevenLabsModel.MULTI_LINGUAL
name: str = "eleven_labs_text2speech"
description: str = (
"A wrapper around Eleven Labs Text2Speech. "
"Useful for when you need to convert text to speech. "
"It supports multiple languages, including English, German, Polish, "
"Spanish, Italian, French, Portuguese, and Hindi. "
)
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
_ = get_from_dict_or_env(values, "eleven_api_key", "ELEVEN_API_KEY")
return values
def _run(
self,
task: str,
) -> str:
"""Use the tool."""
elevenlabs = _import_elevenlabs()
try:
speech = elevenlabs.generate(text=task, model=self.model)
with tempfile.NamedTemporaryFile(
mode="bx", suffix=".wav", delete=False
) as f:
f.write(speech)
return f.name
except Exception as e:
raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")
def play(self, speech_file: str) -> None:
"""Play the text as speech."""
elevenlabs = _import_elevenlabs()
with open(speech_file, mode="rb") as f:
speech = f.read()
elevenlabs.play(speech)
def stream_speech(self, query: str) -> None:
"""Stream the text as speech as it is generated.
Play the text in your speakers."""
elevenlabs = _import_elevenlabs()
speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
elevenlabs.stream(speech_stream)
def save(self, speech_file: str, path: str) -> None:
"""Save the speech file to a path."""
raise NotImplementedError("Saving not implemented for this tool.")
def __str__(self):
return "ElevenLabsText2SpeechTool"

@ -93,7 +93,7 @@ class HuggingfaceLLM:
set_logger(logger):
Set logger.
Examples:
>>> llm = HuggingfaceLLM(
... model_id="EleutherAI/gpt-neo-2.7B",
@ -426,19 +426,19 @@ class HuggingfaceLLM:
def set_distributed(self, distributed):
"""Set distributed"""
self.distributed = distributed
def set_decoding(self, decoding):
"""Set decoding"""
self.decoding = decoding
def set_max_workers(self, max_workers):
"""Set max_workers"""
self.max_workers = max_workers
def set_repitition_penalty(self, repitition_penalty):
"""Set repitition_penalty"""
self.repitition_penalty = repitition_penalty
def set_no_repeat_ngram_size(self, no_repeat_ngram_size):
"""Set no_repeat_ngram_size"""
self.no_repeat_ngram_size = no_repeat_ngram_size
@ -458,7 +458,7 @@ class HuggingfaceLLM:
def set_quantize(self, quantize):
"""Set quantize"""
self.quantize = quantize
def set_quantization_config(self, quantization_config):
"""Set quantization_config"""
self.quantization_config = quantization_config
@ -477,4 +477,4 @@ class HuggingfaceLLM:
def set_logger(self, logger):
"""Set logger"""
self.logger = logger
self.logger = logger

@ -0,0 +1,185 @@
# Agent process automation
system_prompt_1 = """You are a RPA(Robotic Process Automation) agent, you can write and test a RPA-Python-Code to connect different APPs together to reach a specific user query.
RPA-Python-Code:
1. Each actions and triggers of APPs are defined as Action/Trigger-Functions, once you provide the specific_params for a function, then we will implement and test it **with some features that can influence outside-world and is transparent to you**.
2. A RPA process is implemented as a workflow-function. the mainWorkflow function is activated when the trigger's conditions are reached.
3. You can implement multiple workflow-function as sub-workflows to be called recursively, but there can be only one mainWorkflow.
4. We will automatically test the workflows and actions with the Pinned-Data afer you change the specific_params.
Action/Trigger-Function: All the functions have the same following parameters:
1.integration_name: where this function is from. A integration represent a list of actions and triggers from a APP.
2.resource_name: This is the second category of a integration.
3.operation_name: This is the third category of a integration. (integration->resouce->operation)
4.specific_params: This is a json field, you will only see how to given this field after the above fields are selected.
5.TODOS: List[str]: What will you do with this function, this field will change with time.
6.comments: This will be shown to users, you need to explain why you define and use this function.
Workflow-Function:
1. Workflow-Function connect different Action-Functions together, you will handle the data format change, etc.
2. You must always have a mainWorkflow, whose inputs are a Trigger-function's output. If you define multiple triggers, The mainWorkflow will be activated when one of the trigger are activated, you must handle data type changes.
3. You can define multiple subworkflow-function, Which whose inputs are provided by other workflows, You need to handle data-formats.
Testing-When-Implementing: We will **automatically** test all your actions, triggers and workflows with the pinned input data **at each time** once you change it.
1. Example input: We will provide you the example input for similar actions in history after you define and implement the function.
2. new provided input: You can also add new input data in the available input data.
3. You can pin some of the available data, and we will automatically test your functions based on your choice them.
4. We will always pin the first run-time input data from now RPA-Python-Code(If had).
5.Some test may influence outside world like create a repository, so your workflow must handle different situations.
Data-Format: We ensure all the input/output data in transparent action functions have the format of List of Json: [{...}], length > 0
1.All items in the list have the same json schema. The transparent will be activated for each item in the input-data. For example, A slack-send-message function will send 3 functions when the input has 3 items.
2.All the json item must have a "json" field, in which are some custom fields.
3.Some functions' json items have a additional "binary" field, which contains raw data of images, csv, etc.
4.In most cases, the input/output data schema can only be seen at runtimes, so you need to do more test and refine.
JavaScript-Expression:
1. You can use JavaScript expressions in the specific_params to access the input data directly. Write one as a string starting with "=", and provide the expression inside a "{{...}}" block.
2. Use "{{$json["xxx"]}}" to obtain the "json" field in each item of the input data.
3. You can use expressions in "string", "number", "boolean" and "json" type fields, such as:
string: "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old."
boolean: "={{$json["age"] > 20}}"
number: "={{$json["year"] + 10.5}}"
json: "={ "new_age":{{$json["year"] + 5}} }"
For example, in slack-send-message. The input looks like:
[
{
"json": {
"name": "Alice",
"age": 15,
}
},
{
"json": {
"name": "Jack",
"age": 20,
}
}
]
When you set the field "message text" as "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old.", then the message will be sent as:
[
"Hello Alice, you are 15 years old.",
"Hello Jack, you are 20 years old.",
]
Based on the above information, the full RPA-Python-Code looks like the following:
```
from transparent_server import transparent_action, transparent_trigger
# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
# Available_datas: All the available data: data_1, data_2...
# Pinned_data_ID: All the input data you pinned and their execution results
# ID=1, output: xxx
# ID=3, output: xxx
# Runtime_input_data: The runtime input of this function (first run)
# Runtime_output_data: The corresponding output
def action_1(input_data: [{...}]):
# comments: some comments to users. Always give/change this when defining and implementing
# TODOS:
# 1. I will provide the information at runtime
# 2. I will test the node
# 3. ...Always give/change this when defining and implementing
specific_params = {
"key_1": value_1,
"key_2": [
{
"subkey_2": value_2,
}
],
"key_3": {
"subkey_3": value_3,
},
# You will implement this after function-define
}
function = transparent_action(integration=xxx, resource=yyy, operation=zzz)
output_data = function.run(input_data=input_data, params=params)
return output_data
def action_2(input_data: [{...}]): ...
def action_3(input_data: [{...}]): ...
def action_4(input_data: [{...}]): ...
# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
# Trigger functions have no input and share the same output format. So we will provide you the example_output once you change the code here.
def trigger_1():
# comments: some comments to users. Always give/change this when defining and implementing
# TODOS:
# 1. I will provide the information at runtime
# 2. I will test the node
# 3. ...Always give/change this when defining and implementing
specific_params = {
"key_1": value_1,
"key_2": [
{
"subkey_2": value_2,
}
],
"key_3": {
"subkey_3": value_3,
},
# You will implement this after function-define
}
function = transparent_trigger(integration=xxx, resource=yyy, operation=zzz)
output_data = function.run(input_data=input_data, params=params)
return output_data
def trigger_2(input_data: [{...}]): ...
def trigger_3(input_data: [{...}]): ...
# A subworkflow takes inputs with the same JSON schema and can be called by another workflow.
def subworkflow_1(father_workflow_input: [{...}]): ...
def subworkflow_2(father_workflow_input: [{...}]): ...
# If you defined the trigger node, we will show you the mocked trigger input here.
# If you have implemented the workflow, we will automatically run the workflow for all the mock trigger inputs and tell you the results.
def mainWorkflow(trigger_input: [{...}]):
# comments: some comments to users. Always give/change this when defining and implementing
# TODOS:
# 1. I will provide the information at runtime
# 2. I will test the node
# 3. ...Always give/change this when defining and implementing
# some complex logics here
output_data = trigger_input
return output_data
```
"""
system_prompt_2 = """You will define and implement functions progressively for many steps. At each step, you can do one of the following actions:
1. functions_define: Define a list of functions(Action and Trigger). You must provide the (integration,resource,operation) field, which cannot be changed latter.
2. function_implement: After function define, we will provide you the specific_param schema of the target function. You can provide(or override) the specific_param by this function. We will show your available test_data after you implement functions.
3. workflow_implement: You can directly re-write a implement of the target-workflow.
4. add_test_data: Beside the provided hostory data, you can also add your custom test data for a function.
5. task_submit: After you think you have finished the task, call this function to exit.
Remember:
1.Always provide thought, plans and criticisim before giving an action.
2.Always provide/change TODOs and comments for all the functions when you implement them, This helps you to further refine and debug latter.
3.We will test functions automatically, you only need to change the pinned data.
"""
system_prompt_3 = """The user query:
{{user_query}}
You have access to use the following actions and triggers:
{{flatten_tools}}
"""
history_prompt = """In the {{action_count}}'s time, You made the following action:
{{action}}
"""
user_prompt = """Now the codes looks like this:
```
{{now_codes}}
```
{{refine_prompt}}
Give your next action together with thought, plans and criticism:
"""

@ -2,5 +2,6 @@
# from swarms.structs.task import Task
from swarms.structs.flow import Flow
from swarms.structs.sequential_workflow import SequentialWorkflow
from swarms.structs.autoscaler import AutoScaler
__all__ = ["Flow", "SequentialWorkflow"]
__all__ = ["Flow", "SequentialWorkflow", "AutoScaler"]

@ -9,6 +9,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from termcolor import colored
from swarms.utils.code_interpreter import SubprocessCodeInterpreter
from swarms.utils.parse_code import extract_code_in_backticks_in_string
# Prompts
DYNAMIC_STOP_PROMPT = """
When you have finished the task from the Human, output a special token: <DONE>
@ -120,6 +123,55 @@ class Flow:
dynamic_temperature(bool): Dynamical temperature handling
**kwargs (Any): Any additional keyword arguments
Methods:
run: Run the autonomous agent loop
run_concurrent: Run the autonomous agent loop concurrently
bulk_run: Run the autonomous agent loop in bulk
save: Save the flow history to a file
load: Load the flow history from a file
validate_response: Validate the response based on certain criteria
print_history_and_memory: Print the history and memory of the flow
step: Execute a single step in the flow interaction
graceful_shutdown: Gracefully shutdown the system saving the state
run_with_timeout: Run the loop but stop if it takes longer than the timeout
analyze_feedback: Analyze the feedback for issues
undo_last: Undo the last response and return the previous state
add_response_filter: Add a response filter to filter out certain words from the response
apply_reponse_filters: Apply the response filters to the response
filtered_run: Filter the response
interactive_run: Interactive run mode
streamed_generation: Stream the generation of the response
get_llm_params: Extracts and returns the parameters of the llm object for serialization.
agent_history_prompt: Generate the agent history prompt
add_task_to_memory: Add the task to the memory
add_message_to_memory: Add the message to the memory
add_message_to_memory_and_truncate: Add the message to the memory and truncate
print_dashboard: Print dashboard
activate_autonomous_agent: Print the autonomous agent activation message
dynamic_temperature: Dynamically change the temperature
_check_stopping_condition: Check if the stopping condition is met
format_prompt: Format the prompt
get_llm_init_params: Get the llm init params
provide_feedback: Allow users to provide feedback on the responses
truncate_history: Take the history and truncate it to fit into the model context length
extract_tool_commands: Extract the tool commands from the text
parse_and_execute_tools: Parse and execute the tools
execute_tools: Execute the tool with the provided parameters
construct_dynamic_prompt: Construct the dynamic prompt
get_tool_description: Get the tool description
find_tool_by_name: Find a tool by name
parse_tool_command: Parse the text for tool usage
_run: Generate a result using the provided keyword args.
from_llm_and_template: Create FlowStream from LLM and a string template.
from_llm_and_template_file: Create FlowStream from LLM and a template file.
save_state: Save the state of the flow
load_state: Load the state of the flow
run_async: Run the flow asynchronously
arun: Run the flow asynchronously
run_code: Run the code in the response
Example:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import Flow
@ -161,6 +213,7 @@ class Flow:
context_length: int = 8192,
user_name: str = "Human:",
self_healing: bool = False,
code_interpreter: bool = False,
**kwargs: Any,
):
self.llm = llm
@ -193,6 +246,8 @@ class Flow:
self.autosave = autosave
self.response_filters = []
self.self_healing = self_healing
self.code_interpreter = code_interpreter
self.code_executor = SubprocessCodeInterpreter()
def provide_feedback(self, feedback: str) -> None:
"""Allow users to provide feedback on the responses."""
@ -446,6 +501,9 @@ class Flow:
task,
**kwargs,
)
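# When the code interpreter is enabled, extract any fenced code from the response and execute it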
if self.code_interpreter:
self.run_code(response)
# If there are any tools then parse and execute them
# if self.tools:
# self.parse_and_execute_tools(response)
@ -537,6 +595,9 @@ class Flow:
task,
**kwargs,
)
if self.code_interpreter:
self.run_code(response)
# If there are any tools then parse and execute them
# if self.tools:
# self.parse_and_execute_tools(response)
@ -1032,6 +1093,80 @@ class Flow:
"""Update the retry interval"""
self.retry_interval = retry_interval
def reset(self):
"""Reset the flow"""
self.memory = []
def run_code(self, code: str):
"""
text -> extract the code inside triple backticks -> run the code
"""
parsed_code = extract_code_in_backticks_in_string(code)
run_code = self.code_executor.run(parsed_code)
return run_code
def tool_prompt_prep(self, api_docs: str = None, required_api: str = None):
"""
Prepare the tool prompt
"""
PROMPT = f"""
# Task
You will be provided with a list of APIs. These APIs will have a
description and a list of parameters and return types for each tool. Your
task involves creating 3 varied, complex, and detailed user scenarios
that require at least 5 API calls to complete involving at least 3
different APIs. One of these APIs will be explicitly provided and the
other two will be chosen by you.
For instance, given the APIs: SearchHotels, BookHotel, CancelBooking,
GetNFLNews. Given that GetNFLNews is explicitly provided, your scenario
should articulate something akin to:
"The user wants to see if the Broncos won their last game (GetNFLNews).
They then want to see if that qualifies them for the playoffs and who
they will be playing against (GetNFLNews). The Broncos did make it into
the playoffs, so the user wants to watch the game in person. They want to
look for hotels where the playoffs are occurring (GetNBANews +
SearchHotels). After looking at the options, the user chooses to book a
3-day stay at the cheapest 4-star option (BookHotel)."
This scenario exemplifies a scenario using 5 API calls. The scenario is
complex, detailed, and concise as desired. The scenario also includes two
APIs used in tandem, the required API, GetNBANews to search for the
playoffs location and SearchHotels to find hotels based on the returned
location. Usage of multiple APIs in tandem is highly desirable and will
receive a higher score. Ideally each scenario should contain one or more
instances of multiple APIs being used in tandem.
Note that this scenario does not use all the APIs given and re-uses the "
GetNBANews" API. Re-using APIs is allowed, but each scenario should
involve at least 3 different APIs. Note that API usage is also included
in the scenario, but exact parameters are not necessary. You must use a
different combination of APIs for each scenario. All APIs must be used in
at least one scenario. You can only use the APIs provided in the APIs
section.
Note that API calls are not explicitly mentioned and their uses are
included in parentheses. This behaviour should be mimicked in your
response.
Deliver your response in this format:
- Scenario 1: <Scenario1>
- Scenario 2: <Scenario2>
- Scenario 3: <Scenario3>
# APIs
{api_docs}
# Response
Required API: {required_api}
Scenarios with >=5 API calls:
- Scenario 1: <Scenario1>
"""
def self_healing(self, **kwargs):
"""
Self healing by debugging errors and refactoring its own code
@ -1041,9 +1176,52 @@ class Flow:
"""
pass
def refactor_code(self):
"""
Refactor the code
"""
# Add your code here to refactor the code
pass
# def refactor_code(
# self,
# file: str,
# changes: List,
# confirm: bool = False
# ):
# """
# Refactor the code
# """
# with open(file) as f:
# original_file_lines = f.readlines()
# # Filter out the changes that are not confirmed
# operation_changes = [
# change for change in changes if "operation" in change
# ]
# explanations = [
# change["explanation"] for change in changes if "explanation" in change
# ]
# # Sort the changes in reverse line order
# # explanations.sort(key=lambda x: x["line", reverse=True])
# # def error_prompt_inject(
# # self,
# # file_path: str,
# # args: List,
# # error: str,
# # ):
# # with open(file_path, "r") as f:
# # file_lines = f.readlines()
# # file_with_lines = []
# # for i, line in enumerate(file_lines):
# # file_with_lines.append(str(i + 1) + "" + line)
# # file_with_lines = "".join(file_with_lines)
# # prompt = f"""
# # Here is the script that needs fixing:\n\n
# # {file_with_lines}\n\n
# # Here are the arguments it was provided:\n\n
# # {args}\n\n
# # Here is the error message:\n\n
# # {error}\n
# # "Please provide your suggested changes, and remember to stick to the "
# # exact format as described above.
# # """
# # # Print(prompt)

@ -1,17 +1,14 @@
from swarms.swarms.dialogue_simulator import DialogueSimulator
from swarms.swarms.autoscaler import AutoScaler
# from swarms.swarms.orchestrate import Orchestrator
from swarms.structs.autoscaler import AutoScaler
from swarms.swarms.god_mode import GodMode
from swarms.swarms.simple_swarm import SimpleSwarm
from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker
from swarms.swarms.multi_agent_collab import MultiAgentCollaboration
__all__ = [
"DialogueSimulator",
"AutoScaler",
# "Orchestrator",
"GodMode",
"SimpleSwarm",
"MultiAgentDebate",
"select_speaker",
"MultiAgentCollaboration",
]

@ -1,7 +1,13 @@
import json
import random
from typing import List
import tenacity
from langchain.output_parsers import RegexParser
from swarms.structs.flow import Flow
from swarms.utils.logger import logger
# utils
class BidOutputParser(RegexParser):
@ -17,7 +23,7 @@ bid_parser = BidOutputParser(
)
def select_next_speaker(step: int, agents, director) -> int:
def select_next_speaker_director(step: int, agents, director) -> int:
# if the step is even => director
# => director selects next speaker
if step % 2 == 1:
@ -27,27 +33,81 @@ def select_next_speaker(step: int, agents, director) -> int:
return idx
# Define a selection function
def select_speaker_round_table(step: int, agents) -> int:
# This function selects the speaker in a round-robin fashion
return step % len(agents)
# main
class MultiAgentCollaboration:
"""
Multi-agent collaboration class.
Attributes:
agents (List[Flow]): The agents in the collaboration.
selection_function (callable): The function that selects the next speaker.
Defaults to select_next_speaker.
max_iters (int): The maximum number of iterations. Defaults to 10.
Methods:
reset: Resets the state of all agents.
inject: Injects a message into the collaboration.
inject_agent: Injects an agent into the collaboration.
step: Steps through the collaboration.
ask_for_bid: Asks an agent for a bid.
select_next_speaker: Selects the next speaker.
run: Runs the collaboration.
format_results: Formats the results of the run method.
Usage:
>>> from swarms.models import MultiAgentCollaboration
>>> from swarms.models import Flow
>>> from swarms.models import OpenAIChat
>>> from swarms.models import Anthropic
"""
def __init__(
self,
agents,
selection_function,
agents: List[Flow],
selection_function: callable = select_next_speaker_director,
max_iters: int = 10,
autosave: bool = True,
saved_file_path_name: str = "multi_agent_collab.json",
stopping_token: str = "<DONE>",
logging: bool = True,
):
self.agents = agents
self._step = 0
self.select_next_speaker = selection_function
self._step = 0
self.max_iters = max_iters
self.autosave = autosave
self.saved_file_path_name = saved_file_path_name
self.stopping_token = stopping_token
self.results = []
self.logger = logger
self.logging = logging
def reset(self):
"""Resets the state of all agents."""
for agent in self.agents:
agent.reset()
def inject(self, name: str, message: str):
"""Injects a message into the multi-agent collaboration."""
for agent in self.agents:
agent.run(f"Name {name} and message: {message}")
self._step += 1
def inject_agent(self, agent: Flow):
"""Injects an agent into the multi-agent collaboration."""
self.agents.append(agent)
def step(self) -> tuple[str, str]:
"""Steps through the multi-agent collaboration."""
speaker_idx = self.select_next_speaker(self._step, self.agents)
speaker = self.agents[speaker_idx]
message = speaker.send()
@ -56,8 +116,16 @@ class MultiAgentCollaboration:
for receiver in self.agents:
receiver.receive(speaker.name, message)
self._step += 1
if self.logging:
self.log_step(speaker, message)
return speaker.name, message
def log_step(self, speaker: str, response: str):
"""Logs the step of the multi-agent collaboration."""
self.logger.info(f"{speaker.name}: {response}")
@tenacity.retry(
stop=tenacity.stop_after_attempt(10),
wait=tenacity.wait_none(),
@ -68,6 +136,7 @@ class MultiAgentCollaboration:
retry_error_callback=lambda retry_state: 0,
)
def ask_for_bid(self, agent) -> str:
"""Asks an agent for a bid."""
bid_string = agent.bid()
bid = int(bid_parser.parse(bid_string)["bid"])
return bid
@ -77,6 +146,7 @@ class MultiAgentCollaboration:
step: int,
agents,
) -> int:
"""Selects the next speaker."""
bids = []
for agent in agents:
bid = self.ask_for_bid(agent)
@ -86,15 +156,67 @@ class MultiAgentCollaboration:
idx = random.choice(max_indices)
return idx
def run(self, max_iters: int = 10):
@tenacity.retry(
stop=tenacity.stop_after_attempt(10),
wait=tenacity.wait_none(),
retry=tenacity.retry_if_exception_type(ValueError),
before_sleep=lambda retry_state: print(
f"ValueError occured: {retry_state.outcome.exception()}, retying..."
),
retry_error_callback=lambda retry_state: 0,
)
def run(self):
"""Runs the multi-agent collaboration."""
n = 0
self.reset()
self.inject("Debate Moderator")
print("(Debate Moderator): ")
print("\n")
while n < max_iters:
while n < self.max_iters:
name, message = self.step()
print(f"({name}): {message}")
print("\n")
n += 1
def format_results(self, results):
"""Formats the results of the run method"""
formatted_results = "\n".join(
[f"{result['agent']} responded: {result['response']}" for result in results]
)
return formatted_results
def save(self):
"""Saves the state of all agents."""
state = {
"step": self._step,
"results": [
{"agent": r["agent"].name, "response": r["response"]}
for r in self.results
],
}
with open(self.saved_file_path_name, "w") as file:
json.dump(state, file)
def load(self):
"""Loads the state of all agents."""
with open(self.saved_file_path_name, "r") as file:
state = json.load(file)
self._step = state["step"]
self.results = state["results"]
return state
def __repr__(self):
return f"MultiAgentCollaboration(agents={self.agents}, selection_function={self.select_next_speaker}, max_iters={self.max_iters}, autosave={self.autosave}, saved_file_path_name={self.saved_file_path_name})"
def performance(self):
"""Tracks and reports the performance of each agent"""
performance_data = {}
for agent in self.agents:
performance_data[agent.name] = agent.get_performance_metrics()
return performance_data
def set_interaction_rules(self, rules):
"""Sets the interaction rules for each agent"""
self.interaction_rules = rules

@ -1,76 +0,0 @@
from swarms.structs.flow import Flow
# Define a selection function
def select_speaker(step: int, agents) -> int:
# This function selects the speaker in a round-robin fashion
return step % len(agents)
class MultiAgentDebate:
"""
MultiAgentDebate
Args:
agents: Flow
selection_func: callable
max_iters: int
Usage:
>>> from swarms import MultiAgentDebate
>>> from swarms.structs.flow import Flow
>>> agents = Flow()
>>> agents.append(lambda x: x)
>>> agents.append(lambda x: x)
>>> agents.append(lambda x: x)
"""
def __init__(
self,
agents: Flow,
selection_func: callable = select_speaker,
max_iters: int = None,
):
self.agents = agents
self.selection_func = selection_func
self.max_iters = max_iters
def inject_agent(self, agent):
"""Injects an agent into the debate"""
self.agents.append(agent)
def run(
self,
task: str,
):
"""
MultiAgentDebate
Args:
task: str
Returns:
results: list
"""
results = []
for i in range(self.max_iters or len(self.agents)):
speaker_idx = self.selection_func(i, self.agents)
speaker = self.agents[speaker_idx]
response = speaker(task)
results.append({"response": response})
return results
def update_task(self, task: str):
"""Update the task"""
self.task = task
def format_results(self, results):
"""Format the results"""
formatted_results = "\n".join(
[f"Agent responded: {result['response']}" for result in results]
)
return formatted_results

@ -0,0 +1,154 @@
from enum import Enum, unique, auto
import abc
import hashlib
import re
from typing import List, Optional
import json
from dataclasses import dataclass, field
@unique
class LLMStatusCode(Enum):
SUCCESS = 0
ERROR = 1
@unique
class NodeType(Enum):
action = auto()
trigger = auto()
@unique
class WorkflowType(Enum):
Main = auto()
Sub = auto()
@unique
class ToolCallStatus(Enum):
ToolCallSuccess = auto()
ToolCallPartlySuccess = auto()
NoSuchTool = auto()
NoSuchFunction = auto()
InputCannotParsed = auto()
UndefinedParam = auto()
ParamTypeError = auto()
UnSupportedParam = auto()
UnsupportedExpression = auto()
ExpressionError = auto()
RequiredParamUnprovided = auto()
@unique
class TestDataType(Enum):
NoInput = auto()
TriggerInput = auto()
ActionInput = auto()
SubWorkflowInput = auto()
@unique
class RunTimeStatus(Enum):
FunctionExecuteSuccess = auto()
TriggerActivatedSuccess = auto()
ErrorRaisedHere = auto()
ErrorRaisedInner = auto()
DidNotImplemented = auto()
DidNotBeenCalled = auto()
@dataclass
class TestResult:
"""
Responsible for handling the data structure of [{}]
"""
data_type: TestDataType = TestDataType.ActionInput
input_data: Optional[list] = field(default_factory=lambda: [])
runtime_status: RunTimeStatus = RunTimeStatus.DidNotBeenCalled
visit_times: int = 0
error_message: str = ""
output_data: Optional[list] = field(default_factory=lambda: [])
def load_from_json(self):
pass
def to_json(self):
pass
def to_str(self):
prompt = f"""
This function has been executed for {self.visit_times} times. Last execution:
1.Status: {self.runtime_status.name}
2.Input:
{self.input_data}
3.Output:
{self.output_data}"""
return prompt
@dataclass
class Action:
content: str = ""
thought: str = ""
plan: List[str] = field(default_factory=lambda: [])
criticism: str = ""
tool_name: str = ""
tool_input: dict = field(default_factory=lambda: {})
tool_output_status: ToolCallStatus = ToolCallStatus.ToolCallSuccess
tool_output: str = ""
def to_json(self):
try:
tool_output = json.loads(self.tool_output)
except (json.JSONDecodeError, TypeError):
tool_output = self.tool_output
return {
"thought": self.thought,
"plan": self.plan,
"criticism": self.criticism,
"tool_name": self.tool_name,
"tool_input": self.tool_input,
"tool_output_status": self.tool_output_status.name,
"tool_output": tool_output,
}
@dataclass
class userQuery:
task: str
additional_information: List[str] = field(default_factory=lambda: [])
refine_prompt: str = field(default_factory=lambda: "")
def print_self(self):
lines = [self.task]
for info in self.additional_information:
lines.append(f"- {info}")
return "\n".join(lines)
class Singleton(abc.ABCMeta, type):
"""
Singleton metaclass for ensuring only one instance of a class.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
"""Call method for the singleton metaclass."""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class AbstractSingleton(abc.ABC, metaclass=Singleton):
"""
Abstract singleton class for ensuring only one instance of a class.
"""

@ -33,8 +33,6 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
done (threading.Event): An event that is set when the subprocess is done running code.
Example:
>>> from swarms.utils.code_interpreter import SubprocessCodeInterpreter
"""
def __init__(self):
@ -89,7 +87,7 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
daemon=True,
).start()
def run(self, code):
def run(self, code: str):
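# Retry running the code up to max_retries times if the subprocess fails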
retry_count = 0
max_retries = 3

@ -0,0 +1,501 @@
"""Logging modules"""
import logging
import os
import random
import re
import time
import json
from logging import LogRecord
from typing import Any
from colorama import Fore, Style
from swarms.utils.apa import Action, ToolCallStatus
# from autogpt.speech import say_text
class JsonFileHandler(logging.FileHandler):
def __init__(self, filename, mode="a", encoding=None, delay=False):
"""
Initializes a new instance of the class with the specified file name, mode, encoding, and delay settings.
Parameters:
filename (str): The name of the file to be opened.
mode (str, optional): The mode in which the file is opened. Defaults to "a" (append).
encoding (str, optional): The encoding used to read or write the file. Defaults to None.
delay (bool, optional): If True, the file opening is delayed until the first IO operation. Defaults to False.
Returns:
None
"""
super().__init__(filename, mode, encoding, delay)
def emit(self, record):
"""
Writes the formatted log record to a JSON file.
Parameters:
record (LogRecord): The log record to be emitted.
Returns:
None
"""
json_data = json.loads(self.format(record))
with open(self.baseFilename, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)
class JsonFormatter(logging.Formatter):
def format(self, record):
"""
Format the given record and return the message.
Args:
record (object): The log record to be formatted.
Returns:
str: The formatted message from the record.
"""
return record.msg
class Logger:
"""
Logger that handles titles in different colors.
Outputs logs in console, activity.log, and error.log
For console handler: simulates typing
"""
def __init__(self):
"""
Initializes the class and sets up the logging configuration.
Args:
None
Returns:
None
"""
# create log directory if it doesn't exist
this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_file = "activity.log"
error_file = "error.log"
console_formatter = AutoGptFormatter("%(title_color)s %(message)s")
# Create a handler for console which simulates typing
self.typing_console_handler = TypingConsoleHandler()
# self.typing_console_handler = ConsoleHandler()
self.typing_console_handler.setLevel(logging.INFO)
self.typing_console_handler.setFormatter(console_formatter)
# Create a handler for console without typing simulation
self.console_handler = ConsoleHandler()
self.console_handler.setLevel(logging.DEBUG)
self.console_handler.setFormatter(console_formatter)
# Info handler in activity.log
self.file_handler = logging.FileHandler(
os.path.join(log_dir, log_file), "a", "utf-8"
)
self.file_handler.setLevel(logging.DEBUG)
info_formatter = AutoGptFormatter(
"%(asctime)s %(levelname)s %(title)s %(message_no_color)s"
)
self.file_handler.setFormatter(info_formatter)
# Error handler error.log
error_handler = logging.FileHandler(
os.path.join(log_dir, error_file), "a", "utf-8"
)
error_handler.setLevel(logging.ERROR)
error_formatter = AutoGptFormatter(
"%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s"
" %(message_no_color)s"
)
error_handler.setFormatter(error_formatter)
self.typing_logger = logging.getLogger("TYPER")
self.typing_logger.addHandler(self.typing_console_handler)
# self.typing_logger.addHandler(self.console_handler)
self.typing_logger.addHandler(self.file_handler)
self.typing_logger.addHandler(error_handler)
self.typing_logger.setLevel(logging.DEBUG)
self.logger = logging.getLogger("LOGGER")
self.logger.addHandler(self.console_handler)
self.logger.addHandler(self.file_handler)
self.logger.addHandler(error_handler)
self.logger.setLevel(logging.DEBUG)
self.json_logger = logging.getLogger("JSON_LOGGER")
self.json_logger.addHandler(self.file_handler)
self.json_logger.addHandler(error_handler)
self.json_logger.setLevel(logging.DEBUG)
self.speak_mode = False
self.chat_plugins = []
def typewriter_log(
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
):
"""
Logs a message to the typewriter.
Args:
title (str, optional): The title of the log message. Defaults to "".
title_color (str, optional): The color of the title. Defaults to "".
content (str or list, optional): The content of the log message. Defaults to "".
speak_text (bool, optional): Whether to speak the log message. Defaults to False.
level (int, optional): The logging level of the message. Defaults to logging.INFO.
Returns:
None
"""
for plugin in self.chat_plugins:
plugin.report(f"{title}. {content}")
if content:
if isinstance(content, list):
content = " ".join(content)
else:
content = ""
self.typing_logger.log(
level, content, extra={"title": title, "color": title_color}
)
def debug(
self,
message,
title="",
title_color="",
):
"""
Logs a debug message.
Args:
message (str): The debug message to log.
title (str, optional): The title of the log message. Defaults to "".
title_color (str, optional): The color of the log message title. Defaults to "".
Returns:
None
"""
self._log(title, title_color, message, logging.DEBUG)
def info(
self,
message,
title="",
title_color="",
):
"""
Logs an informational message.
Args:
message (str): The message to be logged.
title (str, optional): The title of the log message. Defaults to "".
title_color (str, optional): The color of the log title. Defaults to "".
Returns:
None
"""
self._log(title, title_color, message, logging.INFO)
def warn(
self,
message,
title="",
title_color="",
):
"""
Logs a warning message.
Args:
message (str): The warning message.
title (str, optional): The title of the warning message. Defaults to "".
title_color (str, optional): The color of the title. Defaults to "".
"""
self._log(title, title_color, message, logging.WARN)
def error(self, title, message=""):
"""
Logs an error message with the given title and optional message.
Parameters:
title (str): The title of the error message.
message (str, optional): The optional additional message for the error. Defaults to an empty string.
"""
self._log(title, Fore.RED, message, logging.ERROR)
def _log(
self,
title: str = "",
title_color: str = "",
message: str = "",
level=logging.INFO,
):
"""
Logs a message with the given title and message at the specified log level.
Parameters:
title (str): The title of the log message. Defaults to an empty string.
title_color (str): The color of the log message title. Defaults to an empty string.
message (str): The log message. Defaults to an empty string.
level (int): The log level. Defaults to logging.INFO.
Returns:
None
"""
if message:
if isinstance(message, list):
message = " ".join(message)
self.logger.log(
level, message, extra={"title": str(title), "color": str(title_color)}
)
def set_level(self, level):
"""
Set the level of the logger and the typing_logger.
Args:
level: The level to set the logger to.
Returns:
None
"""
self.logger.setLevel(level)
self.typing_logger.setLevel(level)
def double_check(self, additionalText=None):
"""
A function that performs a double check on the configuration.
Parameters:
additionalText (str, optional): Additional text to be included in the double check message.
Returns:
None
"""
if not additionalText:
            additionalText = (
                "Please ensure you've set up and configured everything"
                " correctly. Read https://github.com/Torantulino/Auto-GPT#readme to "
                "double check. You can also create a GitHub issue or join the discord"
                " and ask there!"
            )
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
def log_json(self, data: Any, file_name: str) -> None:
"""
Logs the given JSON data to a specified file.
Args:
data (Any): The JSON data to be logged.
file_name (str): The name of the file to log the data to.
Returns:
None: This function does not return anything.
"""
# Define log directory
this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs")
# Create a handler for JSON files
json_file_path = os.path.join(log_dir, file_name)
json_data_handler = JsonFileHandler(json_file_path)
json_data_handler.setFormatter(JsonFormatter())
# Log the JSON data using the custom file handler
self.json_logger.addHandler(json_data_handler)
self.json_logger.debug(data)
self.json_logger.removeHandler(json_data_handler)
def get_log_directory(self):
"""
Returns the absolute path to the log directory.
Returns:
str: The absolute path to the log directory.
"""
this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs")
return os.path.abspath(log_dir)
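# Illustrative usage of the Logger defined above (method signatures follow
# their definitions in this module):
#
#   logger = Logger()
#   logger.typewriter_log("SETUP", Fore.CYAN, "All handlers are ready.")
#   logger.warn("Disk space is low.", title="RESOURCES", title_color=Fore.YELLOW)
#   logger.error("CONFIG", "Missing API key.")
#   logger.log_json('{"run_id": 1}', "run_metadata.json")  # must be a JSON string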
"""
Output stream to console using simulated typing
"""
class TypingConsoleHandler(logging.StreamHandler):
def emit(self, record):
"""
Emit a log record to the console with simulated typing effect.
Args:
record (LogRecord): The log record to be emitted.
Returns:
None
Raises:
Exception: If an error occurs while emitting the log record.
"""
min_typing_speed = 0.05
max_typing_speed = 0.10
# min_typing_speed = 0.005
# max_typing_speed = 0.010
msg = self.format(record)
try:
            # Temporarily encode newlines and 4-space indents so they survive
            # the whitespace split below.
            transfer_enter = "<ENTER>"
            msg_transferred = str(msg).replace("\n", transfer_enter)
            transfer_space = "<4SPACE>"
            msg_transferred = str(msg_transferred).replace("    ", transfer_space)
            words = msg_transferred.split()
            words = [word.replace(transfer_enter, "\n") for word in words]
            words = [word.replace(transfer_space, "    ") for word in words]
for i, word in enumerate(words):
print(word, end="", flush=True)
if i < len(words) - 1:
print(" ", end="", flush=True)
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
time.sleep(typing_speed)
# type faster after each word
min_typing_speed = min_typing_speed * 0.95
max_typing_speed = max_typing_speed * 0.95
print()
except Exception:
self.handleError(record)
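# A quick worked example of the decay above: the per-word pause starts in
# [0.05, 0.10] seconds and is multiplied by 0.95 after each word, so after the
# n-th word it is roughly uniform(0.05, 0.10) * 0.95**n seconds; by word 50
# the handler types at about 8% of its initial delay (0.95**50 ≈ 0.077).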
class ConsoleHandler(logging.StreamHandler):
def emit(self, record) -> None:
"""
Emit the log record.
Args:
record (logging.LogRecord): The log record to emit.
Returns:
None: This function does not return anything.
"""
msg = self.format(record)
try:
print(msg)
except Exception:
self.handleError(record)
class AutoGptFormatter(logging.Formatter):
"""
Allows to handle custom placeholders 'title_color' and 'message_no_color'.
To use this formatter, make sure to pass 'color', 'title' as log extras.
"""
def format(self, record: LogRecord) -> str:
"""
Formats a log record into a string representation.
Args:
record (LogRecord): The log record to be formatted.
Returns:
str: The formatted log record as a string.
"""
if hasattr(record, "color"):
record.title_color = (
getattr(record, "color")
+ getattr(record, "title", "")
+ " "
+ Style.RESET_ALL
)
else:
record.title_color = getattr(record, "title", "")
        # Ensure 'title' exists so format strings that reference it don't fail.
        record.title = getattr(record, "title", "")
if hasattr(record, "msg"):
record.message_no_color = remove_color_codes(getattr(record, "msg"))
else:
record.message_no_color = ""
return super().format(record)
def remove_color_codes(s: str) -> str:
"""
Removes color codes from a given string.
Args:
s (str): The string from which to remove color codes.
Returns:
str: The string with color codes removed.
"""
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", s)
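# Example: remove_color_codes(f"{Fore.RED}alert{Style.RESET_ALL}") == "alert",
# since colorama emits ANSI escapes matched by the regex above.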
logger = Logger()
def print_action_base(action: Action):
"""
Print the different properties of an Action object.
Parameters:
action (Action): The Action object to print.
Returns:
None
"""
    if action.content != "":
        logger.typewriter_log("Content:", Fore.YELLOW, f"{action.content}")
    logger.typewriter_log("Thought:", Fore.YELLOW, f"{action.thought}")
    if len(action.plan) > 0:
        logger.typewriter_log(
            "Plan:",
            Fore.YELLOW,
        )
        for line in action.plan:
            line = line.lstrip("- ")
            logger.typewriter_log("- ", Fore.GREEN, line.strip())
    logger.typewriter_log("Criticism:", Fore.YELLOW, f"{action.criticism}")
def print_action_tool(action: Action):
"""
Prints the details of an action tool.
Args:
action (Action): The action object containing the tool details.
Returns:
None
"""
logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}")
logger.typewriter_log(f"Tool Input:", Fore.BLUE, f"{action.tool_input}")
output = action.tool_output if action.tool_output != "" else "None"
logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}")
color = Fore.RED
if action.tool_output_status == ToolCallStatus.ToolCallSuccess:
color = Fore.GREEN
elif action.tool_output_status == ToolCallStatus.InputCannotParsed:
color = Fore.YELLOW
logger.typewriter_log(
f"Tool Call Status:",
Fore.BLUE,
f"{color}{action.tool_output_status.name}{Style.RESET_ALL}",
)
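# Illustrative sketch, assuming Action accepts these fields (used by the
# printers above) as keyword arguments; the values are invented:
#
#   action = Action(
#       tool_name="search",
#       tool_input='{"query": "swarms"}',
#       tool_output="3 results",
#       tool_output_status=ToolCallStatus.ToolCallSuccess,
#   )
#   print_action_tool(action)  # prints the tool fields with a green status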

@ -0,0 +1,80 @@
import pytest
from unittest.mock import patch, mock_open
from swarms.models.eleven_labs import ElevenLabsText2SpeechTool, ElevenLabsModel
import os
from dotenv import load_dotenv
load_dotenv()
# Define some test data
SAMPLE_TEXT = "Hello, this is a test."
API_KEY = os.environ.get("ELEVEN_API_KEY")
EXPECTED_SPEECH_FILE = "expected_speech.wav"
@pytest.fixture
def eleven_labs_tool():
return ElevenLabsText2SpeechTool()
# Basic functionality tests
def test_run_text_to_speech(eleven_labs_tool):
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
assert isinstance(speech_file, str)
assert speech_file.endswith(".wav")
def test_play_speech(eleven_labs_tool):
with patch("builtins.open", mock_open(read_data="fake_audio_data")):
eleven_labs_tool.play(EXPECTED_SPEECH_FILE)
def test_stream_speech(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
eleven_labs_tool.stream_speech(SAMPLE_TEXT)
mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False)
# Testing fixture and environment variables
def test_api_key_validation(eleven_labs_tool):
with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY):
values = {"eleven_api_key": None}
validated_values = eleven_labs_tool.validate_environment(values)
assert "eleven_api_key" in validated_values
# Mocking the external library
def test_run_text_to_speech_with_mock(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch(
"your_module._import_elevenlabs"
) as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.return_value = b"fake_audio_data"
eleven_labs_tool.run(SAMPLE_TEXT)
assert mock_file.call_args[1]["suffix"] == ".wav"
assert mock_file.call_args[1]["delete"] is False
assert mock_file().write.call_args[0][0] == b"fake_audio_data"
# Exception testing
def test_run_text_to_speech_error_handling(eleven_labs_tool):
with patch("your_module._import_elevenlabs") as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.side_effect = Exception("Test Exception")
with pytest.raises(
RuntimeError,
match="Error while running ElevenLabsText2SpeechTool: Test Exception",
):
eleven_labs_tool.run(SAMPLE_TEXT)
# Parameterized testing
@pytest.mark.parametrize(
"model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL]
)
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
eleven_labs_tool.model = model
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
assert isinstance(speech_file, str)
assert speech_file.endswith(".wav")

@ -0,0 +1,26 @@
from swarms.models import __all__
EXPECTED_ALL = [
"Anthropic",
"Petals",
"Mistral",
"OpenAI",
"AzureOpenAI",
"OpenAIChat",
"Zephyr",
"Idefics",
# "Kosmos",
"Vilt",
"Nougat",
"LayoutLMDocumentQA",
"BioGPT",
"HuggingfaceLLM",
"MPT7B",
"WizardLLMStoryTeller",
# "GPT4Vision",
# "Dalle3",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@ -1,5 +1,5 @@
from unittest.mock import patch
-from swarms.swarms.autoscaler import AutoScaler
+from swarms.structs.autoscaler import AutoScaler
from swarms.models import OpenAIChat
from swarms.structs import Flow

@ -1,76 +1,168 @@
from unittest.mock import patch
import json
import os
import pytest
from unittest.mock import Mock
from swarms.structs import Flow
from swarms.models import OpenAIChat
from swarms.swarms.multi_agent_collab import (
MultiAgentCollaboration,
Worker,
select_next_speaker,
select_next_speaker_director,
select_speaker_round_table,
)
# Sample agents for testing
agent1 = Flow(llm=OpenAIChat(), max_loops=2)
agent2 = Flow(llm=OpenAIChat(), max_loops=2)
agents = [agent1, agent2]
def test_multiagentcollaboration_initialization():
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
assert isinstance(multiagentcollaboration, MultiAgentCollaboration)
assert len(multiagentcollaboration.agents) == 5
assert multiagentcollaboration._step == 0
@patch("swarms.workers.Worker.reset")
def test_multiagentcollaboration_reset(mock_reset):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
multiagentcollaboration.reset()
assert mock_reset.call_count == 5
@patch("swarms.workers.Worker.run")
def test_multiagentcollaboration_inject(mock_run):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
multiagentcollaboration.inject("Agent 1", "Hello, world!")
assert multiagentcollaboration._step == 1
assert mock_run.call_count == 5
@patch("swarms.workers.Worker.send")
@patch("swarms.workers.Worker.receive")
def test_multiagentcollaboration_step(mock_receive, mock_send):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
multiagentcollaboration.step()
assert multiagentcollaboration._step == 1
assert mock_send.call_count == 5
assert mock_receive.call_count == 25
@patch("swarms.workers.Worker.bid")
def test_multiagentcollaboration_ask_for_bid(mock_bid):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
result = multiagentcollaboration.ask_for_bid(Worker)
assert isinstance(result, int)
@patch("swarms.workers.Worker.bid")
def test_multiagentcollaboration_select_next_speaker(mock_bid):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
result = multiagentcollaboration.select_next_speaker(1, [Worker] * 5)
assert isinstance(result, int)
@patch("swarms.workers.Worker.send")
@patch("swarms.workers.Worker.receive")
def test_multiagentcollaboration_run(mock_receive, mock_send):
multiagentcollaboration = MultiAgentCollaboration(
agents=[Worker] * 5, selection_function=select_next_speaker
)
multiagentcollaboration.run(max_iters=5)
assert multiagentcollaboration._step == 6
assert mock_send.call_count == 30
assert mock_receive.call_count == 150
@pytest.fixture
def collaboration():
return MultiAgentCollaboration(agents)
def test_collaboration_initialization(collaboration):
assert len(collaboration.agents) == 2
assert callable(collaboration.select_next_speaker)
assert collaboration.max_iters == 10
assert collaboration.results == []
    assert collaboration.logging is True
def test_reset(collaboration):
collaboration.reset()
for agent in collaboration.agents:
assert agent.step == 0
def test_inject(collaboration):
collaboration.inject("TestName", "TestMessage")
for agent in collaboration.agents:
assert "TestName" in agent.history[-1]
assert "TestMessage" in agent.history[-1]
def test_inject_agent(collaboration):
agent3 = Flow(llm=OpenAIChat(), max_loops=2)
collaboration.inject_agent(agent3)
assert len(collaboration.agents) == 3
assert agent3 in collaboration.agents
def test_step(collaboration):
collaboration.step()
for agent in collaboration.agents:
assert agent.step == 1
def test_ask_for_bid(collaboration):
agent = Mock()
agent.bid.return_value = "<5>"
bid = collaboration.ask_for_bid(agent)
assert bid == 5
def test_select_next_speaker(collaboration):
collaboration.select_next_speaker = Mock(return_value=0)
idx = collaboration.select_next_speaker(1, collaboration.agents)
assert idx == 0
def test_run(collaboration):
collaboration.run()
for agent in collaboration.agents:
assert agent.step == collaboration.max_iters
def test_format_results(collaboration):
collaboration.results = [{"agent": "Agent1", "response": "Response1"}]
formatted_results = collaboration.format_results(collaboration.results)
assert "Agent1 responded: Response1" in formatted_results
def test_save_and_load(collaboration):
collaboration.save()
loaded_state = collaboration.load()
assert loaded_state["_step"] == collaboration._step
assert loaded_state["results"] == collaboration.results
def test_performance(collaboration):
performance_data = collaboration.performance()
for agent in collaboration.agents:
assert agent.name in performance_data
assert "metrics" in performance_data[agent.name]
def test_set_interaction_rules(collaboration):
rules = {"rule1": "action1", "rule2": "action2"}
collaboration.set_interaction_rules(rules)
assert hasattr(collaboration, "interaction_rules")
assert collaboration.interaction_rules == rules
def test_repr(collaboration):
repr_str = repr(collaboration)
assert isinstance(repr_str, str)
assert "MultiAgentCollaboration" in repr_str
def test_load(collaboration):
state = {"step": 5, "results": [{"agent": "Agent1", "response": "Response1"}]}
with open(collaboration.saved_file_path_name, "w") as file:
json.dump(state, file)
loaded_state = collaboration.load()
assert loaded_state["_step"] == state["step"]
assert loaded_state["results"] == state["results"]
def test_save(collaboration, tmp_path):
collaboration.saved_file_path_name = tmp_path / "test_save.json"
collaboration.save()
with open(collaboration.saved_file_path_name, "r") as file:
saved_data = json.load(file)
assert saved_data["_step"] == collaboration._step
assert saved_data["results"] == collaboration.results
# Add more tests here...
# Example of parameterized test for different selection functions
@pytest.mark.parametrize(
"selection_function", [select_next_speaker_director, select_speaker_round_table]
)
def test_selection_functions(collaboration, selection_function):
collaboration.select_next_speaker = selection_function
assert callable(collaboration.select_next_speaker)
# Add more parameterized tests for different scenarios...
# Example of exception testing
def test_exception_handling(collaboration):
agent = Mock()
agent.bid.side_effect = ValueError("Invalid bid")
with pytest.raises(ValueError):
collaboration.ask_for_bid(agent)
# Add more exception testing...
# Example of environment variable testing (if applicable)
@pytest.mark.parametrize("env_var", ["ENV_VAR_1", "ENV_VAR_2"])
def test_environment_variables(collaboration, monkeypatch, env_var):
monkeypatch.setenv(env_var, "test_value")
assert os.getenv(env_var) == "test_value"
