multi agent debate -> multi agent collab + reset method in flow + updated tests for multiagentcollab, code interpreter fix in flow, loggers, eleven lab, tests, and docs
parent
954e580af0
commit
fddc12f828
@ -0,0 +1,82 @@
|
||||
# ElevenLabsText2SpeechTool Documentation
|
||||
|
||||
## Table of Contents
|
||||
1. [Introduction](#introduction)
|
||||
2. [Class Overview](#class-overview)
|
||||
- [Attributes](#attributes)
|
||||
3. [Installation](#installation)
|
||||
4. [Usage](#usage)
|
||||
- [Initialization](#initialization)
|
||||
- [Converting Text to Speech](#converting-text-to-speech)
|
||||
- [Playing and Streaming Speech](#playing-and-streaming-speech)
|
||||
5. [Exception Handling](#exception-handling)
|
||||
6. [Advanced Usage](#advanced-usage)
|
||||
7. [Contributing](#contributing)
|
||||
8. [References](#references)
|
||||
|
||||
## 1. Introduction <a name="introduction"></a>
|
||||
The `ElevenLabsText2SpeechTool` is a Python class designed to simplify the process of converting text to speech using the Eleven Labs Text2Speech API. This tool is a wrapper around the API and provides a convenient interface for generating speech from text. It supports multiple languages, making it suitable for a wide range of applications, including voice assistants, audio content generation, and more.
|
||||
|
||||
## 2. Class Overview <a name="class-overview"></a>
|
||||
### Attributes <a name="attributes"></a>
|
||||
- `model` (Union[ElevenLabsModel, str]): The model to use for text to speech. Defaults to `ElevenLabsModel.MULTI_LINGUAL`.
|
||||
- `name` (str): The name of the tool. Defaults to `"eleven_labs_text2speech"`.
|
||||
- `description` (str): A brief description of the tool. Defaults to a detailed explanation of its functionality.
|
||||
|
||||
## 3. Installation <a name="installation"></a>
|
||||
To use the `ElevenLabsText2SpeechTool`, you need to install the required dependencies and have access to the Eleven Labs Text2Speech API. Follow these steps:
|
||||
|
||||
1. Install the `elevenlabs` library:
|
||||
```
|
||||
pip install elevenlabs
|
||||
```
|
||||
|
||||
2. Install the `swarms` library
|
||||
`pip install swarms`
|
||||
|
||||
3. Set up your API key by following the instructions at [Eleven Labs Documentation](https://docs.elevenlabs.io/welcome/introduction).
|
||||
|
||||
## 4. Usage <a name="usage"></a>
|
||||
### Initialization <a name="initialization"></a>
|
||||
To get started, create an instance of the `ElevenLabsText2SpeechTool`. You can customize the `model` attribute if needed.
|
||||
|
||||
```python
|
||||
from swarms.models import ElevenLabsText2SpeechTool
|
||||
|
||||
stt = ElevenLabsText2SpeechTool(model=ElevenLabsModel.MONO_LINGUAL)
|
||||
```
|
||||
|
||||
### Converting Text to Speech <a name="converting-text-to-speech"></a>
|
||||
You can use the `run` method to convert text to speech. It returns the path to the generated speech file.
|
||||
|
||||
```python
|
||||
speech_file = stt.run("Hello, this is a test.")
|
||||
```
|
||||
|
||||
### Playing and Streaming Speech <a name="playing-and-streaming-speech"></a>
|
||||
- Use the `play` method to play the generated speech file.
|
||||
|
||||
```python
|
||||
stt.play(speech_file)
|
||||
```
|
||||
|
||||
- Use the `stream_speech` method to stream the text as speech. It plays the speech in real-time.
|
||||
|
||||
```python
|
||||
stt.stream_speech("Hello world!")
|
||||
```
|
||||
|
||||
## 5. Exception Handling <a name="exception-handling"></a>
|
||||
The `ElevenLabsText2SpeechTool` handles exceptions gracefully. If an error occurs during the conversion process, it raises a `RuntimeError` with an informative error message.
|
||||
|
||||
## 6. Advanced Usage <a name="advanced-usage"></a>
|
||||
- You can implement custom error handling and logging to further enhance the functionality of this tool.
|
||||
- For advanced users, extending the class to support additional features or customization is possible.
|
||||
|
||||
## 7. Contributing <a name="contributing"></a>
|
||||
Contributions to this tool are welcome. Feel free to open issues, submit pull requests, or provide feedback to improve its functionality and documentation.
|
||||
|
||||
## 8. References <a name="references"></a>
|
||||
- [Eleven Labs Text2Speech API Documentation](https://docs.elevenlabs.io/welcome/introduction)
|
||||
|
||||
This documentation provides a comprehensive guide to using the `ElevenLabsText2SpeechTool`. It covers installation, basic usage, advanced features, and contribution guidelines. Refer to the [References](#references) section for additional resources.
|
@ -1,517 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import xml.etree.ElementTree as ET
|
||||
import zipfile
|
||||
from tempfile import mkdtemp
|
||||
from typing import Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyautogui
|
||||
import requests
|
||||
import semver
|
||||
import undetected_chromedriver as uc # type: ignore
|
||||
import yaml
|
||||
from extension import load_extension
|
||||
from pydantic import BaseModel
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.keys import Keys
|
||||
from selenium.webdriver.remote.webelement import WebElement
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _is_blank_agent(agent_name: str) -> bool:
|
||||
with open(f"agents/{agent_name}.py", "r") as agent_file:
|
||||
agent_data = agent_file.read()
|
||||
with open("src/template.py", "r") as template_file:
|
||||
template_data = template_file.read()
|
||||
return agent_data == template_data
|
||||
|
||||
|
||||
def record(agent_name: str, autotab_ext_path: Optional[str] = None):
|
||||
if not os.path.exists("agents"):
|
||||
os.makedirs("agents")
|
||||
|
||||
if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local":
|
||||
if not _is_blank_agent(agent_name=agent_name):
|
||||
raise Exception(f"Agent with name {agent_name} already exists")
|
||||
driver = get_driver( # noqa: F841
|
||||
autotab_ext_path=autotab_ext_path,
|
||||
record_mode=True,
|
||||
)
|
||||
# Need to keep a reference to the driver so that it doesn't get garbage collected
|
||||
with open("src/template.py", "r") as file:
|
||||
data = file.read()
|
||||
|
||||
with open(f"agents/{agent_name}.py", "w") as file:
|
||||
file.write(data)
|
||||
|
||||
print(
|
||||
"\033[34mYou have the Python debugger open, you can run commands in it like you"
|
||||
" would in a normal Python shell.\033[0m"
|
||||
)
|
||||
print(
|
||||
"\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and"
|
||||
" press enter.\033[0m"
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
record("agent")
|
||||
|
||||
|
||||
def extract_domain_from_url(url: str):
|
||||
# url = http://username:password@hostname:port/path?arg=value#anchor
|
||||
parsed_url = urlparse(url)
|
||||
hostname = parsed_url.hostname
|
||||
if hostname is None:
|
||||
raise ValueError(f"Could not extract hostname from url {url}")
|
||||
if hostname.startswith("www."):
|
||||
hostname = hostname[4:]
|
||||
return hostname
|
||||
|
||||
|
||||
class AutotabChromeDriver(uc.Chrome):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def find_element_with_retry(
|
||||
self, by=By.ID, value: Optional[str] = None
|
||||
) -> WebElement:
|
||||
try:
|
||||
return super().find_element(by, value)
|
||||
except Exception as e:
|
||||
# TODO: Use an LLM to retry, finding a similar element on the DOM
|
||||
breakpoint()
|
||||
raise e
|
||||
|
||||
|
||||
def open_plugin(driver: AutotabChromeDriver):
|
||||
print("Opening plugin sidepanel")
|
||||
driver.execute_script("document.activeElement.blur();")
|
||||
pyautogui.press("esc")
|
||||
pyautogui.hotkey("command", "shift", "y", interval=0.05) # mypy: ignore
|
||||
|
||||
|
||||
def open_plugin_and_login(driver: AutotabChromeDriver):
|
||||
if config.autotab_api_key is not None:
|
||||
backend_url = (
|
||||
"http://localhost:8000"
|
||||
if config.environment == "local"
|
||||
else "https://api.autotab.com"
|
||||
)
|
||||
driver.get(f"{backend_url}/auth/signin-api-key-page")
|
||||
response = requests.post(
|
||||
f"{backend_url}/auth/signin-api-key",
|
||||
json={"api_key": config.autotab_api_key},
|
||||
)
|
||||
cookie = response.json()
|
||||
if response.status_code != 200:
|
||||
if response.status_code == 401:
|
||||
raise Exception("Invalid API key")
|
||||
else:
|
||||
raise Exception(
|
||||
f"Error {response.status_code} from backend while logging you in"
|
||||
f" with your API key: {response.text}"
|
||||
)
|
||||
cookie["name"] = cookie["key"]
|
||||
del cookie["key"]
|
||||
driver.add_cookie(cookie)
|
||||
|
||||
driver.get("https://www.google.com")
|
||||
open_plugin(driver)
|
||||
else:
|
||||
print("No autotab API key found, heading to autotab.com to sign up")
|
||||
|
||||
url = (
|
||||
"http://localhost:3000/dashboard"
|
||||
if config.environment == "local"
|
||||
else "https://autotab.com/dashboard"
|
||||
)
|
||||
driver.get(url)
|
||||
time.sleep(0.5)
|
||||
|
||||
open_plugin(driver)
|
||||
|
||||
|
||||
def get_driver(
|
||||
autotab_ext_path: Optional[str] = None, record_mode: bool = False
|
||||
) -> AutotabChromeDriver:
|
||||
options = webdriver.ChromeOptions()
|
||||
options.add_argument("--no-sandbox") # Necessary for running
|
||||
options.add_argument(
|
||||
"--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
" (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
|
||||
)
|
||||
options.add_argument("--enable-webgl")
|
||||
options.add_argument("--enable-3d-apis")
|
||||
options.add_argument("--enable-clipboard-read-write")
|
||||
options.add_argument("--disable-popup-blocking")
|
||||
|
||||
if autotab_ext_path is None:
|
||||
load_extension()
|
||||
options.add_argument("--load-extension=./src/extension/autotab")
|
||||
else:
|
||||
options.add_argument(f"--load-extension={autotab_ext_path}")
|
||||
|
||||
options.add_argument("--allow-running-insecure-content")
|
||||
options.add_argument("--disable-web-security")
|
||||
options.add_argument(f"--user-data-dir={mkdtemp()}")
|
||||
options.binary_location = config.chrome_binary_location
|
||||
driver = AutotabChromeDriver(options=options)
|
||||
if record_mode:
|
||||
open_plugin_and_login(driver)
|
||||
|
||||
return driver
|
||||
|
||||
|
||||
class SiteCredentials(BaseModel):
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
login_with_google_account: Optional[str] = None
|
||||
login_url: Optional[str] = None
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
if self.name is None:
|
||||
self.name = self.email
|
||||
|
||||
|
||||
class GoogleCredentials(BaseModel):
|
||||
credentials: Dict[str, SiteCredentials]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
for cred in self.credentials.values():
|
||||
cred.login_url = "https://accounts.google.com/v3/signin"
|
||||
|
||||
@property
|
||||
def default(self) -> SiteCredentials:
|
||||
if "default" not in self.credentials:
|
||||
if len(self.credentials) == 1:
|
||||
return list(self.credentials.values())[0]
|
||||
raise Exception("No default credentials found in config")
|
||||
return self.credentials["default"]
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
autotab_api_key: Optional[str]
|
||||
credentials: Dict[str, SiteCredentials]
|
||||
google_credentials: GoogleCredentials
|
||||
chrome_binary_location: str
|
||||
environment: str
|
||||
|
||||
@classmethod
|
||||
def load_from_yaml(cls, path: str):
|
||||
with open(path, "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
_credentials = {}
|
||||
for domain, creds in config.get("credentials", {}).items():
|
||||
if "login_url" not in creds:
|
||||
creds["login_url"] = f"https://{domain}/login"
|
||||
site_creds = SiteCredentials(**creds)
|
||||
_credentials[domain] = site_creds
|
||||
for alt in creds.get("alts", []):
|
||||
_credentials[alt] = site_creds
|
||||
|
||||
google_credentials = {}
|
||||
for creds in config.get("google_credentials", []):
|
||||
credentials: SiteCredentials = SiteCredentials(**creds)
|
||||
google_credentials[credentials.name] = credentials
|
||||
|
||||
chrome_binary_location = config.get("chrome_binary_location")
|
||||
if chrome_binary_location is None:
|
||||
raise Exception("Must specify chrome_binary_location in config")
|
||||
|
||||
autotab_api_key = config.get("autotab_api_key")
|
||||
if autotab_api_key == "...":
|
||||
autotab_api_key = None
|
||||
|
||||
return cls(
|
||||
autotab_api_key=autotab_api_key,
|
||||
credentials=_credentials,
|
||||
google_credentials=GoogleCredentials(credentials=google_credentials),
|
||||
chrome_binary_location=config.get("chrome_binary_location"),
|
||||
environment=config.get("environment", "prod"),
|
||||
)
|
||||
|
||||
def get_site_credentials(self, domain: str) -> SiteCredentials:
|
||||
credentials = self.credentials[domain].copy()
|
||||
return credentials
|
||||
|
||||
|
||||
config = Config.load_from_yaml(".autotab.yaml")
|
||||
|
||||
|
||||
def is_signed_in_to_google(driver):
|
||||
cookies = driver.get_cookies()
|
||||
return len([c for c in cookies if c["name"] == "SAPISID"]) != 0
|
||||
|
||||
|
||||
def google_login(
|
||||
driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True
|
||||
):
|
||||
print("Logging in to Google")
|
||||
if navigate:
|
||||
driver.get("https://accounts.google.com/")
|
||||
time.sleep(1)
|
||||
if is_signed_in_to_google(driver):
|
||||
print("Already signed in to Google")
|
||||
return
|
||||
|
||||
if os.path.exists("google_cookies.json"):
|
||||
print("cookies exist, doing loading")
|
||||
with open("google_cookies.json", "r") as f:
|
||||
google_cookies = json.load(f)
|
||||
for cookie in google_cookies:
|
||||
if "expiry" in cookie:
|
||||
cookie["expires"] = cookie["expiry"]
|
||||
del cookie["expiry"]
|
||||
driver.execute_cdp_cmd("Network.setCookie", cookie)
|
||||
time.sleep(1)
|
||||
driver.refresh()
|
||||
time.sleep(2)
|
||||
|
||||
if not credentials:
|
||||
credentials = config.google_credentials.default
|
||||
|
||||
if credentials is None:
|
||||
raise Exception("No credentials provided for Google login")
|
||||
|
||||
email_input = driver.find_element(By.CSS_SELECTOR, "[type='email']")
|
||||
email_input.send_keys(credentials.email)
|
||||
email_input.send_keys(Keys.ENTER)
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))
|
||||
)
|
||||
|
||||
password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
|
||||
password_input.send_keys(credentials.password)
|
||||
password_input.send_keys(Keys.ENTER)
|
||||
time.sleep(1.5)
|
||||
print("Successfully logged in to Google")
|
||||
|
||||
cookies = driver.get_cookies()
|
||||
if not is_signed_in_to_google(driver):
|
||||
# Probably wanted to have us solve a captcha, or 2FA or confirm recovery details
|
||||
print("Need 2FA help to log in to Google")
|
||||
# TODO: Show screenshot it to the user
|
||||
breakpoint()
|
||||
|
||||
if not os.path.exists("google_cookies.json"):
|
||||
print("Setting Google cookies for future use")
|
||||
# Log out to have access to the right cookies
|
||||
driver.get("https://accounts.google.com/Logout")
|
||||
time.sleep(2)
|
||||
cookies = driver.get_cookies()
|
||||
cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"]
|
||||
google_cookies = [
|
||||
cookie
|
||||
for cookie in cookies
|
||||
if cookie["domain"] in [".google.com", "accounts.google.com"]
|
||||
and cookie["name"] in cookie_names
|
||||
]
|
||||
with open("google_cookies.json", "w") as f:
|
||||
json.dump(google_cookies, f)
|
||||
|
||||
# Log back in
|
||||
login_button = driver.find_element(
|
||||
By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']"
|
||||
)
|
||||
login_button.click()
|
||||
time.sleep(1)
|
||||
password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
|
||||
password_input.send_keys(credentials.password)
|
||||
password_input.send_keys(Keys.ENTER)
|
||||
|
||||
time.sleep(3)
|
||||
print("Successfully copied Google cookies for the future")
|
||||
|
||||
|
||||
def login(driver, url: str):
|
||||
domain = extract_domain_from_url(url)
|
||||
|
||||
credentials = config.get_site_credentials(domain)
|
||||
login_url = credentials.login_url
|
||||
if credentials.login_with_google_account:
|
||||
google_credentials = config.google_credentials.credentials[
|
||||
credentials.login_with_google_account
|
||||
]
|
||||
_login_with_google(driver, login_url, google_credentials)
|
||||
else:
|
||||
_login(driver, login_url, credentials=credentials)
|
||||
|
||||
|
||||
def _login(driver, url: str, credentials: SiteCredentials):
|
||||
print(f"Logging in to {url}")
|
||||
driver.get(url)
|
||||
time.sleep(2)
|
||||
email_input = driver.find_element(By.NAME, "email")
|
||||
email_input.send_keys(credentials.email)
|
||||
password_input = driver.find_element(By.NAME, "password")
|
||||
password_input.send_keys(credentials.password)
|
||||
password_input.send_keys(Keys.ENTER)
|
||||
|
||||
time.sleep(3)
|
||||
print(f"Successfully logged in to {url}")
|
||||
|
||||
|
||||
def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
|
||||
print(f"Logging in to {url} with Google")
|
||||
|
||||
google_login(driver, credentials=google_credentials)
|
||||
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
main_window = driver.current_window_handle
|
||||
xpath = (
|
||||
"//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with"
|
||||
" Google') or contains(@title, 'Sign in with Google')]"
|
||||
)
|
||||
|
||||
WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath)))
|
||||
driver.find_element(
|
||||
By.XPATH,
|
||||
xpath,
|
||||
).click()
|
||||
|
||||
driver.switch_to.window(driver.window_handles[-1])
|
||||
driver.find_element(
|
||||
By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]"
|
||||
).click()
|
||||
|
||||
driver.switch_to.window(main_window)
|
||||
|
||||
time.sleep(5)
|
||||
print(f"Successfully logged in to {url}")
|
||||
|
||||
|
||||
def update():
|
||||
print("updating extension...")
|
||||
# Download the autotab.crx file
|
||||
response = requests.get(
|
||||
"https://github.com/Planetary-Computers/autotab-extension/raw/main/autotab.crx",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Check if the directory exists, if not create it
|
||||
if os.path.exists("src/extension/.autotab"):
|
||||
shutil.rmtree("src/extension/.autotab")
|
||||
os.makedirs("src/extension/.autotab")
|
||||
|
||||
# Open the file in write binary mode
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
t = tqdm(total=total_size, unit="iB", unit_scale=True)
|
||||
with open("src/extension/.autotab/autotab.crx", "wb") as f:
|
||||
for data in response.iter_content(block_size):
|
||||
t.update(len(data))
|
||||
f.write(data)
|
||||
t.close()
|
||||
if total_size != 0 and t.n != total_size:
|
||||
print("ERROR, something went wrong")
|
||||
|
||||
# Unzip the file
|
||||
with zipfile.ZipFile("src/extension/.autotab/autotab.crx", "r") as zip_ref:
|
||||
zip_ref.extractall("src/extension/.autotab")
|
||||
os.remove("src/extension/.autotab/autotab.crx")
|
||||
if os.path.exists("src/extension/autotab"):
|
||||
shutil.rmtree("src/extension/autotab")
|
||||
os.rename("src/extension/.autotab", "src/extension/autotab")
|
||||
|
||||
|
||||
def should_update():
|
||||
if not os.path.exists("src/extension/autotab"):
|
||||
return True
|
||||
# Fetch the XML file
|
||||
response = requests.get(
|
||||
"https://raw.githubusercontent.com/Planetary-Computers/autotab-extension/main/update.xml"
|
||||
)
|
||||
xml_content = response.content
|
||||
|
||||
# Parse the XML file
|
||||
root = ET.fromstring(xml_content)
|
||||
namespaces = {"ns": "http://www.google.com/update2/response"} # add namespaces
|
||||
xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version")
|
||||
|
||||
# Load the local JSON file
|
||||
with open("src/extension/autotab/manifest.json", "r") as f:
|
||||
json_content = json.load(f)
|
||||
json_version = json_content["version"]
|
||||
# Compare versions
|
||||
return semver.compare(xml_version, json_version) > 0
|
||||
|
||||
|
||||
def load_extension():
|
||||
should_update() and update()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("should update:", should_update())
|
||||
update()
|
||||
|
||||
|
||||
def play(agent_name: Optional[str] = None):
|
||||
if agent_name is None:
|
||||
agent_files = os.listdir("agents")
|
||||
if len(agent_files) == 0:
|
||||
raise Exception("No agents found in agents/ directory")
|
||||
elif len(agent_files) == 1:
|
||||
agent_file = agent_files[0]
|
||||
else:
|
||||
print("Found multiple agent files, please select one:")
|
||||
for i, file in enumerate(agent_files, start=1):
|
||||
print(f"{i}. {file}")
|
||||
|
||||
selected = int(input("Select a file by number: ")) - 1
|
||||
agent_file = agent_files[selected]
|
||||
else:
|
||||
agent_file = f"{agent_name}.py"
|
||||
|
||||
os.system(f"python agents/{agent_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
play()
|
||||
"""
|
||||
|
||||
|
||||
chrome_binary_location: /Applications/Google Chrome.app/Contents/MacOS/Google Chrome
|
||||
|
||||
autotab_api_key: ... # Go to https://autotab.com/dashboard to get your API key, or
|
||||
# run `autotab record` with this field blank and you will be prompted to log in to autotab
|
||||
|
||||
# Optional, programmatically login to services using "Login with Google" authentication
|
||||
google_credentials:
|
||||
- name: default
|
||||
email: ...
|
||||
password: ...
|
||||
|
||||
# Optional, specify alternative accounts to use with Google login on a per-service basis
|
||||
- email: you@gmail.com # Credentials without a name use email as key
|
||||
password: ...
|
||||
|
||||
credentials:
|
||||
notion.so:
|
||||
alts:
|
||||
- notion.com
|
||||
login_with_google_account: default
|
||||
|
||||
figma.com:
|
||||
email: ...
|
||||
password: ...
|
||||
|
||||
airtable.com:
|
||||
login_with_google_account: you@gmail.com
|
||||
"""
|
@ -0,0 +1,104 @@
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator
|
||||
|
||||
from swarms.tools.tool import BaseTool
|
||||
|
||||
|
||||
def _import_elevenlabs() -> Any:
|
||||
try:
|
||||
import elevenlabs
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import elevenlabs, please install `pip install elevenlabs`."
|
||||
) from e
|
||||
return elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsModel(str, Enum):
|
||||
"""Models available for Eleven Labs Text2Speech."""
|
||||
|
||||
MULTI_LINGUAL = "eleven_multilingual_v1"
|
||||
MONO_LINGUAL = "eleven_monolingual_v1"
|
||||
|
||||
|
||||
class ElevenLabsText2SpeechTool(BaseTool):
|
||||
"""Tool that queries the Eleven Labs Text2Speech API.
|
||||
|
||||
In order to set this up, follow instructions at:
|
||||
https://docs.elevenlabs.io/welcome/introduction
|
||||
|
||||
Attributes:
|
||||
model (ElevenLabsModel): The model to use for text to speech.
|
||||
Defaults to ElevenLabsModel.MULTI_LINGUAL.
|
||||
name (str): The name of the tool. Defaults to "eleven_labs_text2speech".
|
||||
description (str): The description of the tool.
|
||||
Defaults to "A wrapper around Eleven Labs Text2Speech. Useful for when you need to convert text to speech. It supports multiple languages, including English, German, Polish, Spanish, Italian, French, Portuguese, and Hindi."
|
||||
|
||||
|
||||
Usage:
|
||||
>>> from swarms.models import ElevenLabsText2SpeechTool
|
||||
>>> stt = ElevenLabsText2SpeechTool()
|
||||
>>> speech_file = stt.run("Hello world!")
|
||||
>>> stt.play(speech_file)
|
||||
>>> stt.stream_speech("Hello world!")
|
||||
|
||||
"""
|
||||
|
||||
model: Union[ElevenLabsModel, str] = ElevenLabsModel.MULTI_LINGUAL
|
||||
|
||||
name: str = "eleven_labs_text2speech"
|
||||
description: str = (
|
||||
"A wrapper around Eleven Labs Text2Speech. "
|
||||
"Useful for when you need to convert text to speech. "
|
||||
"It supports multiple languages, including English, German, Polish, "
|
||||
"Spanish, Italian, French, Portuguese, and Hindi. "
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
_ = get_from_dict_or_env(values, "eleven_api_key", "ELEVEN_API_KEY")
|
||||
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
task: str,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
elevenlabs = _import_elevenlabs()
|
||||
try:
|
||||
speech = elevenlabs.generate(text=task, model=self.model)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="bx", suffix=".wav", delete=False
|
||||
) as f:
|
||||
f.write(speech)
|
||||
return f.name
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")
|
||||
|
||||
def play(self, speech_file: str) -> None:
|
||||
"""Play the text as speech."""
|
||||
elevenlabs = _import_elevenlabs()
|
||||
with open(speech_file, mode="rb") as f:
|
||||
speech = f.read()
|
||||
|
||||
elevenlabs.play(speech)
|
||||
|
||||
def stream_speech(self, query: str) -> None:
|
||||
"""Stream the text as speech as it is generated.
|
||||
Play the text in your speakers."""
|
||||
elevenlabs = _import_elevenlabs()
|
||||
speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
|
||||
elevenlabs.stream(speech_stream)
|
||||
|
||||
def save(self, speech_file: str, path: str) -> None:
|
||||
"""Save the speech file to a path."""
|
||||
raise NotImplementedError("Saving not implemented for this tool.")
|
||||
|
||||
def __str__(self):
|
||||
return "ElevenLabsText2SpeechTool"
|
@ -0,0 +1,185 @@
|
||||
# Agent process automation
|
||||
system_prompt_1 = """You are a RPA(Robotic Process Automation) agent, you can write and test a RPA-Python-Code to connect different APPs together to reach a specific user query.
|
||||
|
||||
RPA-Python-Code:
|
||||
1. Each actions and triggers of APPs are defined as Action/Trigger-Functions, once you provide the specific_params for a function, then we will implement and test it **with some features that can influence outside-world and is transparent to you**.
|
||||
2. A RPA process is implemented as a workflow-function. the mainWorkflow function is activated when the trigger's conditions are reached.
|
||||
3. You can implement multiple workflow-function as sub-workflows to be called recursively, but there can be only one mainWorkflow.
|
||||
4. We will automatically test the workflows and actions with the Pinned-Data afer you change the specific_params.
|
||||
|
||||
Action/Trigger-Function: All the functions have the same following parameters:
|
||||
1.integration_name: where this function is from. A integration represent a list of actions and triggers from a APP.
|
||||
2.resource_name: This is the second category of a integration.
|
||||
3.operation_name: This is the third category of a integration. (integration->resouce->operation)
|
||||
4.specific_params: This is a json field, you will only see how to given this field after the above fields are selected.
|
||||
5.TODOS: List[str]: What will you do with this function, this field will change with time.
|
||||
6.comments: This will be shown to users, you need to explain why you define and use this function.
|
||||
|
||||
Workflow-Function:
|
||||
1. Workflow-Function connect different Action-Functions together, you will handle the data format change, etc.
|
||||
2. You must always have a mainWorkflow, whose inputs are a Trigger-function's output. If you define multiple triggers, The mainWorkflow will be activated when one of the trigger are activated, you must handle data type changes.
|
||||
3. You can define multiple subworkflow-function, Which whose inputs are provided by other workflows, You need to handle data-formats.
|
||||
|
||||
Testing-When-Implementing: We will **automatically** test all your actions, triggers and workflows with the pinned input data **at each time** once you change it.
|
||||
1. Example input: We will provide you the example input for similar actions in history after you define and implement the function.
|
||||
2. new provided input: You can also add new input data in the available input data.
|
||||
3. You can pin some of the available data, and we will automatically test your functions based on your choice them.
|
||||
4. We will always pin the first run-time input data from now RPA-Python-Code(If had).
|
||||
5.Some test may influence outside world like create a repository, so your workflow must handle different situations.
|
||||
|
||||
Data-Format: We ensure all the input/output data in transparent action functions have the format of List of Json: [{...}], length > 0
|
||||
1.All items in the list have the same json schema. The transparent will be activated for each item in the input-data. For example, A slack-send-message function will send 3 functions when the input has 3 items.
|
||||
2.All the json item must have a "json" field, in which are some custom fields.
|
||||
3.Some functions' json items have a additional "binary" field, which contains raw data of images, csv, etc.
|
||||
4.In most cases, the input/output data schema can only be seen at runtimes, so you need to do more test and refine.
|
||||
|
||||
Java-Script-Expression:
|
||||
1.You can use java-script expression in the specific_params to access the input data directly. Use it by a string startswith "=", and provide expression inside a "{{...}}" block.
|
||||
2. Use "{{$json["xxx"]}}" to obtain the "json" field in each item of the input data.
|
||||
3. You can use expression in "string" , "number", "boolean" and "json" type, such as:
|
||||
string: "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old
|
||||
boolean: "={{$json["age"] > 20}}"
|
||||
number: "={{$json["year"] + 10.5}}"
|
||||
json: "={ "new_age":{{$json["year"] + 5}} }"
|
||||
|
||||
For example, in slack-send-message. The input looks like:
|
||||
[
|
||||
{
|
||||
"json": {
|
||||
"name": "Alice",
|
||||
"age": 15,
|
||||
}
|
||||
},
|
||||
{
|
||||
"json": {
|
||||
"name": "Jack",
|
||||
"age": 20,
|
||||
}
|
||||
}
|
||||
]
|
||||
When you set the field "message text" as "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old.", then the message will be send as:
|
||||
[
|
||||
"Hello Alice, you are 15 years old.",
|
||||
"Hello Jack, you are 20 years old.",
|
||||
]
|
||||
|
||||
Based on the above information, the full RPA-Python-Code looks like the following:
|
||||
```
|
||||
from transparent_server import transparent_action, tranparent_trigger
|
||||
|
||||
# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
|
||||
# Avaliable_datas: All the avaliable Datas: data_1, data_2...
|
||||
# Pinned_data_ID: All the input data you pinned and there execution result
|
||||
# ID=1, output: xxx
|
||||
# ID=3, output: xxx
|
||||
# Runtime_input_data: The runtime input of this function(first time)
|
||||
# Runtime_output_data: The corresponding output
|
||||
def action_1(input_data: [{...}]):
|
||||
# comments: some comments to users. Always give/change this when defining and implmenting
|
||||
# TODOS:
|
||||
# 1. I will provide the information in runtime
|
||||
# 2. I will test the node
|
||||
# 3. ...Always give/change this when defining and implmenting
|
||||
specific_params = {
|
||||
"key_1": value_1,
|
||||
"key_2": [
|
||||
{
|
||||
"subkey_2": value_2,
|
||||
}
|
||||
],
|
||||
"key_3": {
|
||||
"subkey_3": value_3,
|
||||
},
|
||||
# You will implement this after function-define
|
||||
}
|
||||
function = transparent_action(integration=xxx, resource=yyy, operation=zzz)
|
||||
output_data = function.run(input_data=input_data, params=params)
|
||||
return output_data
|
||||
|
||||
def action_2(input_data: [{...}]): ...
|
||||
def action_3(input_data: [{...}]): ...
|
||||
def action_4(input_data: [{...}]): ...
|
||||
|
||||
# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
|
||||
# Trigger function has no input, and have the same output_format. So We will provide You the exmaple_output once you changed the code here.
|
||||
def trigger_1():
|
||||
# comments: some comments to users. Always give/change this when defining and implmenting
|
||||
# TODOS:
|
||||
# 1. I will provide the information in runtime
|
||||
# 2. I will test the node
|
||||
# 3. ...Always give/change this when defining and implmenting
|
||||
specific_params = {
|
||||
"key_1": value_1,
|
||||
"key_2": [
|
||||
{
|
||||
"subkey_2": value_2,
|
||||
}
|
||||
],
|
||||
"key_3": {
|
||||
"subkey_3": value_3,
|
||||
},
|
||||
# You will implement this after function-define
|
||||
}
|
||||
function = transparent_trigger(integration=xxx, resource=yyy, operation=zzz)
|
||||
output_data = function.run(input_data=input_data, params=params)
|
||||
return output_data
|
||||
|
||||
def trigger_2(input_data: [{...}]): ...
|
||||
def trigger_3(input_data: [{...}]): ...
|
||||
|
||||
# subworkflow inputs the same json-schema, can be called by another workflow.
|
||||
def subworkflow_1(father_workflow_input: [{...}]): ...
|
||||
def subworkflow_2(father_workflow_input: [{...}]): ...
|
||||
|
||||
# If you defined the trigger node, we will show you the mocked trigger input here.
|
||||
# If you have implemented the workflow, we will automatically run the workflow for all the mock trigger-input and tells you the result.
|
||||
def mainWorkflow(trigger_input: [{...}]):
|
||||
# comments: some comments to users. Always give/change this when defining and implmenting
|
||||
# TODOS:
|
||||
# 1. I will provide the information in runtime
|
||||
# 2. I will test the node
|
||||
# 3. ...Always give/change this when defining and implmenting
|
||||
|
||||
# some complex logics here
|
||||
output_data = trigger_input
|
||||
|
||||
return output_data
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
system_prompt_2 = """You will define and implement functions progressively for many steps. At each step, you can do one of the following actions:
|
||||
1. functions_define: Define a list of functions(Action and Trigger). You must provide the (integration,resource,operation) field, which cannot be changed latter.
|
||||
2. function_implement: After function define, we will provide you the specific_param schema of the target function. You can provide(or override) the specific_param by this function. We will show your available test_data after you implement functions.
|
||||
3. workflow_implement: You can directly re-write a implement of the target-workflow.
|
||||
4. add_test_data: Beside the provided hostory data, you can also add your custom test data for a function.
|
||||
5. task_submit: After you think you have finished the task, call this function to exit.
|
||||
|
||||
Remember:
|
||||
1.Always provide thought, plans and criticisim before giving an action.
|
||||
2.Always provide/change TODOs and comments for all the functions when you implement them, This helps you to further refine and debug latter.
|
||||
3.We will test functions automatically, you only need to change the pinned data.
|
||||
|
||||
"""
|
||||
|
||||
system_prompt_3 = """The user query:
|
||||
{{user_query}}
|
||||
|
||||
You have access to use the following actions and triggers:
|
||||
|
||||
{{flatten_tools}}
|
||||
"""
|
||||
|
||||
history_prompt = """In the {{action_count}}'s time, You made the following action:
|
||||
{{action}}
|
||||
"""
|
||||
|
||||
user_prompt = """Now the codes looks like this:
|
||||
```
|
||||
{{now_codes}}
|
||||
```
|
||||
|
||||
{{refine_prompt}}
|
||||
|
||||
Give your next action together with thought, plans and criticisim:
|
||||
"""
|
@ -1,17 +1,14 @@
|
||||
from swarms.swarms.dialogue_simulator import DialogueSimulator
|
||||
from swarms.swarms.autoscaler import AutoScaler
|
||||
|
||||
# from swarms.swarms.orchestrate import Orchestrator
|
||||
from swarms.structs.autoscaler import AutoScaler
|
||||
from swarms.swarms.god_mode import GodMode
|
||||
from swarms.swarms.simple_swarm import SimpleSwarm
|
||||
from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker
|
||||
from swarms.swarms.multi_agent_collab import MultiAgentCollaboration
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DialogueSimulator",
|
||||
"AutoScaler",
|
||||
# "Orchestrator",
|
||||
"GodMode",
|
||||
"SimpleSwarm",
|
||||
"MultiAgentDebate",
|
||||
"select_speaker",
|
||||
"MultiAgentCollaboration",
|
||||
]
|
||||
|
@ -1,76 +0,0 @@
|
||||
from swarms.structs.flow import Flow
|
||||
|
||||
|
||||
# Define a selection function
|
||||
def select_speaker(step: int, agents) -> int:
|
||||
# This function selects the speaker in a round-robin fashion
|
||||
return step % len(agents)
|
||||
|
||||
|
||||
class MultiAgentDebate:
|
||||
"""
|
||||
MultiAgentDebate
|
||||
|
||||
|
||||
Args:
|
||||
agents: Flow
|
||||
selection_func: callable
|
||||
max_iters: int
|
||||
|
||||
Usage:
|
||||
>>> from swarms import MultiAgentDebate
|
||||
>>> from swarms.structs.flow import Flow
|
||||
>>> agents = Flow()
|
||||
>>> agents.append(lambda x: x)
|
||||
>>> agents.append(lambda x: x)
|
||||
>>> agents.append(lambda x: x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agents: Flow,
|
||||
selection_func: callable = select_speaker,
|
||||
max_iters: int = None,
|
||||
):
|
||||
self.agents = agents
|
||||
self.selection_func = selection_func
|
||||
self.max_iters = max_iters
|
||||
|
||||
def inject_agent(self, agent):
|
||||
"""Injects an agent into the debate"""
|
||||
self.agents.append(agent)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task: str,
|
||||
):
|
||||
"""
|
||||
MultiAgentDebate
|
||||
|
||||
Args:
|
||||
task: str
|
||||
|
||||
Returns:
|
||||
results: list
|
||||
|
||||
"""
|
||||
results = []
|
||||
for i in range(self.max_iters or len(self.agents)):
|
||||
speaker_idx = self.selection_func(i, self.agents)
|
||||
speaker = self.agents[speaker_idx]
|
||||
response = speaker(task)
|
||||
results.append({"response": response})
|
||||
return results
|
||||
|
||||
def update_task(self, task: str):
|
||||
"""Update the task"""
|
||||
self.task = task
|
||||
|
||||
def format_results(self, results):
|
||||
"""Format the results"""
|
||||
formatted_results = "\n".join(
|
||||
[f"Agent responded: {result['response']}" for result in results]
|
||||
)
|
||||
|
||||
return formatted_results
|
@ -0,0 +1,154 @@
|
||||
from enum import Enum, unique, auto
|
||||
import abc
|
||||
import hashlib
|
||||
import re
|
||||
from typing import List, Optional
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@unique
|
||||
class LLMStatusCode(Enum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
|
||||
|
||||
@unique
|
||||
class NodeType(Enum):
|
||||
action = auto()
|
||||
trigger = auto()
|
||||
|
||||
|
||||
@unique
|
||||
class WorkflowType(Enum):
|
||||
Main = auto()
|
||||
Sub = auto()
|
||||
|
||||
|
||||
@unique
|
||||
class ToolCallStatus(Enum):
|
||||
ToolCallSuccess = auto()
|
||||
ToolCallPartlySuccess = auto()
|
||||
NoSuchTool = auto()
|
||||
NoSuchFunction = auto()
|
||||
InputCannotParsed = auto()
|
||||
|
||||
UndefinedParam = auto()
|
||||
ParamTypeError = auto()
|
||||
UnSupportedParam = auto()
|
||||
UnsupportedExpression = auto()
|
||||
ExpressionError = auto()
|
||||
RequiredParamUnprovided = auto()
|
||||
|
||||
|
||||
@unique
|
||||
class TestDataType(Enum):
|
||||
NoInput = auto()
|
||||
TriggerInput = auto()
|
||||
ActionInput = auto()
|
||||
SubWorkflowInput = auto()
|
||||
|
||||
|
||||
@unique
|
||||
class RunTimeStatus(Enum):
|
||||
FunctionExecuteSuccess = auto()
|
||||
TriggerAcivatedSuccess = auto()
|
||||
ErrorRaisedHere = auto()
|
||||
ErrorRaisedInner = auto()
|
||||
DidNotImplemented = auto()
|
||||
DidNotBeenCalled = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""
|
||||
Responsible for handling the data structure of [{}]
|
||||
"""
|
||||
|
||||
data_type: TestDataType = TestDataType.ActionInput
|
||||
|
||||
input_data: Optional[list] = field(default_factory=lambda: [])
|
||||
|
||||
runtime_status: RunTimeStatus = RunTimeStatus.DidNotBeenCalled
|
||||
visit_times: int = 0
|
||||
|
||||
error_message: str = ""
|
||||
output_data: Optional[list] = field(default_factory=lambda: [])
|
||||
|
||||
def load_from_json(self):
|
||||
pass
|
||||
|
||||
def to_json(self):
|
||||
pass
|
||||
|
||||
def to_str(self):
|
||||
prompt = f"""
|
||||
This function has been executed for {self.visit_times} times. Last execution:
|
||||
1.Status: {self.runtime_status.name}
|
||||
2.Input:
|
||||
{self.input_data}
|
||||
|
||||
3.Output:
|
||||
{self.output_data}"""
|
||||
return prompt
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
content: str = ""
|
||||
thought: str = ""
|
||||
plan: List[str] = field(default_factory=lambda: [])
|
||||
criticism: str = ""
|
||||
tool_name: str = ""
|
||||
tool_input: dict = field(default_factory=lambda: {})
|
||||
|
||||
tool_output_status: ToolCallStatus = ToolCallStatus.ToolCallSuccess
|
||||
tool_output: str = ""
|
||||
|
||||
def to_json(self):
|
||||
try:
|
||||
tool_output = json.loads(self.tool_output)
|
||||
except:
|
||||
tool_output = self.tool_output
|
||||
return {
|
||||
"thought": self.thought,
|
||||
"plan": self.plan,
|
||||
"criticism": self.criticism,
|
||||
"tool_name": self.tool_name,
|
||||
"tool_input": self.tool_input,
|
||||
"tool_output_status": self.tool_output_status.name,
|
||||
"tool_output": tool_output,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class userQuery:
|
||||
task: str
|
||||
additional_information: List[str] = field(default_factory=lambda: [])
|
||||
refine_prompt: str = field(default_factory=lambda: "")
|
||||
|
||||
def print_self(self):
|
||||
lines = [self.task]
|
||||
for info in self.additional_information:
|
||||
lines.append(f"- {info}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class Singleton(abc.ABCMeta, type):
|
||||
"""
|
||||
Singleton metaclass for ensuring only one instance of a class.
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
"""Call method for the singleton metaclass."""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class AbstractSingleton(abc.ABC, metaclass=Singleton):
|
||||
"""
|
||||
Abstract singleton class for ensuring only one instance of a class.
|
||||
"""
|
@ -0,0 +1,501 @@
|
||||
"""Logging modules"""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
from logging import LogRecord
|
||||
from typing import Any
|
||||
|
||||
from colorama import Fore, Style
|
||||
from swarms.utils.apa import Action, ToolCallStatus
|
||||
|
||||
|
||||
# from autogpt.speech import say_text
|
||||
class JsonFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
"""
|
||||
Initializes a new instance of the class with the specified file name, mode, encoding, and delay settings.
|
||||
|
||||
Parameters:
|
||||
filename (str): The name of the file to be opened.
|
||||
mode (str, optional): The mode in which the file is opened. Defaults to "a" (append).
|
||||
encoding (str, optional): The encoding used to read or write the file. Defaults to None.
|
||||
delay (bool, optional): If True, the file opening is delayed until the first IO operation. Defaults to False.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Writes the formatted log record to a JSON file.
|
||||
|
||||
Parameters:
|
||||
record (LogRecord): The log record to be emitted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
json_data = json.loads(self.format(record))
|
||||
with open(self.baseFilename, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
"""
|
||||
Format the given record and return the message.
|
||||
|
||||
Args:
|
||||
record (object): The log record to be formatted.
|
||||
|
||||
Returns:
|
||||
str: The formatted message from the record.
|
||||
"""
|
||||
return record.msg
|
||||
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
Logger that handle titles in different colors.
|
||||
Outputs logs in console, activity.log, and errors.log
|
||||
For console handler: simulates typing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the class and sets up the logging configuration.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# create log directory if it doesn't exist
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
log_file = "activity.log"
|
||||
error_file = "error.log"
|
||||
|
||||
console_formatter = AutoGptFormatter("%(title_color)s %(message)s")
|
||||
|
||||
# Create a handler for console which simulate typing
|
||||
self.typing_console_handler = TypingConsoleHandler()
|
||||
# self.typing_console_handler = ConsoleHandler()
|
||||
self.typing_console_handler.setLevel(logging.INFO)
|
||||
self.typing_console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Create a handler for console without typing simulation
|
||||
self.console_handler = ConsoleHandler()
|
||||
self.console_handler.setLevel(logging.DEBUG)
|
||||
self.console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Info handler in activity.log
|
||||
self.file_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, log_file), "a", "utf-8"
|
||||
)
|
||||
self.file_handler.setLevel(logging.DEBUG)
|
||||
info_formatter = AutoGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(title)s %(message_no_color)s"
|
||||
)
|
||||
self.file_handler.setFormatter(info_formatter)
|
||||
|
||||
# Error handler error.log
|
||||
error_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, error_file), "a", "utf-8"
|
||||
)
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
error_formatter = AutoGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s"
|
||||
" %(message_no_color)s"
|
||||
)
|
||||
error_handler.setFormatter(error_formatter)
|
||||
|
||||
self.typing_logger = logging.getLogger("TYPER")
|
||||
self.typing_logger.addHandler(self.typing_console_handler)
|
||||
# self.typing_logger.addHandler(self.console_handler)
|
||||
self.typing_logger.addHandler(self.file_handler)
|
||||
self.typing_logger.addHandler(error_handler)
|
||||
self.typing_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.logger = logging.getLogger("LOGGER")
|
||||
self.logger.addHandler(self.console_handler)
|
||||
self.logger.addHandler(self.file_handler)
|
||||
self.logger.addHandler(error_handler)
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.json_logger = logging.getLogger("JSON_LOGGER")
|
||||
self.json_logger.addHandler(self.file_handler)
|
||||
self.json_logger.addHandler(error_handler)
|
||||
self.json_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.speak_mode = False
|
||||
self.chat_plugins = []
|
||||
|
||||
def typewriter_log(
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
):
|
||||
"""
|
||||
Logs a message to the typewriter.
|
||||
|
||||
Args:
|
||||
title (str, optional): The title of the log message. Defaults to "".
|
||||
title_color (str, optional): The color of the title. Defaults to "".
|
||||
content (str or list, optional): The content of the log message. Defaults to "".
|
||||
speak_text (bool, optional): Whether to speak the log message. Defaults to False.
|
||||
level (int, optional): The logging level of the message. Defaults to logging.INFO.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.chat_plugins:
|
||||
plugin.report(f"{title}. {content}")
|
||||
|
||||
if content:
|
||||
if isinstance(content, list):
|
||||
content = " ".join(content)
|
||||
else:
|
||||
content = ""
|
||||
|
||||
self.typing_logger.log(
|
||||
level, content, extra={"title": title, "color": title_color}
|
||||
)
|
||||
|
||||
def debug(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
"""
|
||||
Logs a debug message.
|
||||
|
||||
Args:
|
||||
message (str): The debug message to log.
|
||||
title (str, optional): The title of the log message. Defaults to "".
|
||||
title_color (str, optional): The color of the log message title. Defaults to "".
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._log(title, title_color, message, logging.DEBUG)
|
||||
|
||||
def info(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
"""
|
||||
Logs an informational message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
title (str, optional): The title of the log message. Defaults to "".
|
||||
title_color (str, optional): The color of the log title. Defaults to "".
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._log(title, title_color, message, logging.INFO)
|
||||
|
||||
def warn(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
"""
|
||||
Logs a warning message.
|
||||
|
||||
Args:
|
||||
message (str): The warning message.
|
||||
title (str, optional): The title of the warning message. Defaults to "".
|
||||
title_color (str, optional): The color of the title. Defaults to "".
|
||||
"""
|
||||
self._log(title, title_color, message, logging.WARN)
|
||||
|
||||
def error(self, title, message=""):
|
||||
"""
|
||||
Logs an error message with the given title and optional message.
|
||||
|
||||
Parameters:
|
||||
title (str): The title of the error message.
|
||||
message (str, optional): The optional additional message for the error. Defaults to an empty string.
|
||||
"""
|
||||
self._log(title, Fore.RED, message, logging.ERROR)
|
||||
|
||||
def _log(
|
||||
self,
|
||||
title: str = "",
|
||||
title_color: str = "",
|
||||
message: str = "",
|
||||
level=logging.INFO,
|
||||
):
|
||||
"""
|
||||
Logs a message with the given title and message at the specified log level.
|
||||
|
||||
Parameters:
|
||||
title (str): The title of the log message. Defaults to an empty string.
|
||||
title_color (str): The color of the log message title. Defaults to an empty string.
|
||||
message (str): The log message. Defaults to an empty string.
|
||||
level (int): The log level. Defaults to logging.INFO.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if message:
|
||||
if isinstance(message, list):
|
||||
message = " ".join(message)
|
||||
self.logger.log(
|
||||
level, message, extra={"title": str(title), "color": str(title_color)}
|
||||
)
|
||||
|
||||
def set_level(self, level):
|
||||
"""
|
||||
Set the level of the logger and the typing_logger.
|
||||
|
||||
Args:
|
||||
level: The level to set the logger to.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.logger.setLevel(level)
|
||||
self.typing_logger.setLevel(level)
|
||||
|
||||
def double_check(self, additionalText=None):
|
||||
"""
|
||||
A function that performs a double check on the configuration.
|
||||
|
||||
Parameters:
|
||||
additionalText (str, optional): Additional text to be included in the double check message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not additionalText:
|
||||
additionalText = (
|
||||
"Please ensure you've setup and configured everything"
|
||||
" correctly. Read https://github.com/Torantulino/Auto-GPT#readme to "
|
||||
"double check. You can also create a github issue or join the discord"
|
||||
" and ask there!"
|
||||
)
|
||||
|
||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||
|
||||
def log_json(self, data: Any, file_name: str) -> None:
|
||||
"""
|
||||
Logs the given JSON data to a specified file.
|
||||
|
||||
Args:
|
||||
data (Any): The JSON data to be logged.
|
||||
file_name (str): The name of the file to log the data to.
|
||||
|
||||
Returns:
|
||||
None: This function does not return anything.
|
||||
"""
|
||||
# Define log directory
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
|
||||
# Create a handler for JSON files
|
||||
json_file_path = os.path.join(log_dir, file_name)
|
||||
json_data_handler = JsonFileHandler(json_file_path)
|
||||
json_data_handler.setFormatter(JsonFormatter())
|
||||
|
||||
# Log the JSON data using the custom file handler
|
||||
self.json_logger.addHandler(json_data_handler)
|
||||
self.json_logger.debug(data)
|
||||
self.json_logger.removeHandler(json_data_handler)
|
||||
|
||||
def get_log_directory(self):
|
||||
"""
|
||||
Returns the absolute path to the log directory.
|
||||
|
||||
Returns:
|
||||
str: The absolute path to the log directory.
|
||||
"""
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
return os.path.abspath(log_dir)
|
||||
|
||||
|
||||
"""
|
||||
Output stream to console using simulated typing
|
||||
"""
|
||||
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record):
|
||||
"""
|
||||
Emit a log record to the console with simulated typing effect.
|
||||
|
||||
Args:
|
||||
record (LogRecord): The log record to be emitted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs while emitting the log record.
|
||||
"""
|
||||
min_typing_speed = 0.05
|
||||
max_typing_speed = 0.10
|
||||
# min_typing_speed = 0.005
|
||||
# max_typing_speed = 0.010
|
||||
|
||||
msg = self.format(record)
|
||||
try:
|
||||
# replace enter & indent with other symbols
|
||||
transfer_enter = "<ENTER>"
|
||||
msg_transfered = str(msg).replace("\n", transfer_enter)
|
||||
transfer_space = "<4SPACE>"
|
||||
msg_transfered = str(msg_transfered).replace(" ", transfer_space)
|
||||
words = msg_transfered.split()
|
||||
words = [word.replace(transfer_enter, "\n") for word in words]
|
||||
words = [word.replace(transfer_space, " ") for word in words]
|
||||
|
||||
for i, word in enumerate(words):
|
||||
print(word, end="", flush=True)
|
||||
if i < len(words) - 1:
|
||||
print(" ", end="", flush=True)
|
||||
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
|
||||
time.sleep(typing_speed)
|
||||
# type faster after each word
|
||||
min_typing_speed = min_typing_speed * 0.95
|
||||
max_typing_speed = max_typing_speed * 0.95
|
||||
print()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class ConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record) -> None:
|
||||
"""
|
||||
Emit the log record.
|
||||
|
||||
Args:
|
||||
record (logging.LogRecord): The log record to emit.
|
||||
|
||||
Returns:
|
||||
None: This function does not return anything.
|
||||
"""
|
||||
msg = self.format(record)
|
||||
try:
|
||||
print(msg)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class AutoGptFormatter(logging.Formatter):
|
||||
"""
|
||||
Allows to handle custom placeholders 'title_color' and 'message_no_color'.
|
||||
To use this formatter, make sure to pass 'color', 'title' as log extras.
|
||||
"""
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
"""
|
||||
Formats a log record into a string representation.
|
||||
|
||||
Args:
|
||||
record (LogRecord): The log record to be formatted.
|
||||
|
||||
Returns:
|
||||
str: The formatted log record as a string.
|
||||
"""
|
||||
if hasattr(record, "color"):
|
||||
record.title_color = (
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
record.title_color = getattr(record, "title", "")
|
||||
|
||||
# Add this line to set 'title' to an empty string if it doesn't exist
|
||||
record.title = getattr(record, "title", "")
|
||||
|
||||
if hasattr(record, "msg"):
|
||||
record.message_no_color = remove_color_codes(getattr(record, "msg"))
|
||||
else:
|
||||
record.message_no_color = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
"""
|
||||
Removes color codes from a given string.
|
||||
|
||||
Args:
|
||||
s (str): The string from which to remove color codes.
|
||||
|
||||
Returns:
|
||||
str: The string with color codes removed.
|
||||
"""
|
||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
return ansi_escape.sub("", s)
|
||||
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
def print_action_base(action: Action):
|
||||
"""
|
||||
Print the different properties of an Action object.
|
||||
|
||||
Parameters:
|
||||
action (Action): The Action object to print.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if action.content != "":
|
||||
logger.typewriter_log(f"content:", Fore.YELLOW, f"{action.content}")
|
||||
logger.typewriter_log(f"Thought:", Fore.YELLOW, f"{action.thought}")
|
||||
if len(action.plan) > 0:
|
||||
logger.typewriter_log(
|
||||
f"Plan:",
|
||||
Fore.YELLOW,
|
||||
)
|
||||
for line in action.plan:
|
||||
line = line.lstrip("- ")
|
||||
logger.typewriter_log("- ", Fore.GREEN, line.strip())
|
||||
logger.typewriter_log(f"Criticism:", Fore.YELLOW, f"{action.criticism}")
|
||||
|
||||
|
||||
def print_action_tool(action: Action):
|
||||
"""
|
||||
Prints the details of an action tool.
|
||||
|
||||
Args:
|
||||
action (Action): The action object containing the tool details.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}")
|
||||
logger.typewriter_log(f"Tool Input:", Fore.BLUE, f"{action.tool_input}")
|
||||
|
||||
output = action.tool_output if action.tool_output != "" else "None"
|
||||
logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}")
|
||||
|
||||
color = Fore.RED
|
||||
if action.tool_output_status == ToolCallStatus.ToolCallSuccess:
|
||||
color = Fore.GREEN
|
||||
elif action.tool_output_status == ToolCallStatus.InputCannotParsed:
|
||||
color = Fore.YELLOW
|
||||
|
||||
logger.typewriter_log(
|
||||
f"Tool Call Status:",
|
||||
Fore.BLUE,
|
||||
f"{color}{action.tool_output_status.name}{Style.RESET_ALL}",
|
||||
)
|
@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, mock_open
|
||||
from swarms.models.eleven_labs import ElevenLabsText2SpeechTool, ElevenLabsModel
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Define some test data
|
||||
SAMPLE_TEXT = "Hello, this is a test."
|
||||
API_KEY = os.environ.get("ELEVEN_API_KEY")
|
||||
EXPECTED_SPEECH_FILE = "expected_speech.wav"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def eleven_labs_tool():
|
||||
return ElevenLabsText2SpeechTool()
|
||||
|
||||
|
||||
# Basic functionality tests
|
||||
def test_run_text_to_speech(eleven_labs_tool):
|
||||
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
|
||||
assert isinstance(speech_file, str)
|
||||
assert speech_file.endswith(".wav")
|
||||
|
||||
|
||||
def test_play_speech(eleven_labs_tool):
|
||||
with patch("builtins.open", mock_open(read_data="fake_audio_data")):
|
||||
eleven_labs_tool.play(EXPECTED_SPEECH_FILE)
|
||||
|
||||
|
||||
def test_stream_speech(eleven_labs_tool):
|
||||
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
|
||||
eleven_labs_tool.stream_speech(SAMPLE_TEXT)
|
||||
mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False)
|
||||
|
||||
|
||||
# Testing fixture and environment variables
|
||||
def test_api_key_validation(eleven_labs_tool):
|
||||
with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY):
|
||||
values = {"eleven_api_key": None}
|
||||
validated_values = eleven_labs_tool.validate_environment(values)
|
||||
assert "eleven_api_key" in validated_values
|
||||
|
||||
|
||||
# Mocking the external library
|
||||
def test_run_text_to_speech_with_mock(eleven_labs_tool):
|
||||
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch(
|
||||
"your_module._import_elevenlabs"
|
||||
) as mock_elevenlabs:
|
||||
mock_elevenlabs_instance = mock_elevenlabs.return_value
|
||||
mock_elevenlabs_instance.generate.return_value = b"fake_audio_data"
|
||||
eleven_labs_tool.run(SAMPLE_TEXT)
|
||||
assert mock_file.call_args[1]["suffix"] == ".wav"
|
||||
assert mock_file.call_args[1]["delete"] is False
|
||||
assert mock_file().write.call_args[0][0] == b"fake_audio_data"
|
||||
|
||||
|
||||
# Exception testing
|
||||
def test_run_text_to_speech_error_handling(eleven_labs_tool):
|
||||
with patch("your_module._import_elevenlabs") as mock_elevenlabs:
|
||||
mock_elevenlabs_instance = mock_elevenlabs.return_value
|
||||
mock_elevenlabs_instance.generate.side_effect = Exception("Test Exception")
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Error while running ElevenLabsText2SpeechTool: Test Exception",
|
||||
):
|
||||
eleven_labs_tool.run(SAMPLE_TEXT)
|
||||
|
||||
|
||||
# Parameterized testing
|
||||
@pytest.mark.parametrize(
|
||||
"model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL]
|
||||
)
|
||||
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
|
||||
eleven_labs_tool.model = model
|
||||
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
|
||||
assert isinstance(speech_file, str)
|
||||
assert speech_file.endswith(".wav")
|
@ -0,0 +1,26 @@
|
||||
from swarms.models import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"Anthropic",
|
||||
"Petals",
|
||||
"Mistral",
|
||||
"OpenAI",
|
||||
"AzureOpenAI",
|
||||
"OpenAIChat",
|
||||
"Zephyr",
|
||||
"Idefics",
|
||||
# "Kosmos",
|
||||
"Vilt",
|
||||
"Nougat",
|
||||
"LayoutLMDocumentQA",
|
||||
"BioGPT",
|
||||
"HuggingfaceLLM",
|
||||
"MPT7B",
|
||||
"WizardLLMStoryTeller",
|
||||
# "GPT4Vision",
|
||||
# "Dalle3",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -1,76 +1,168 @@
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.structs import Flow
|
||||
from swarms.models import OpenAIChat
|
||||
from swarms.swarms.multi_agent_collab import (
|
||||
MultiAgentCollaboration,
|
||||
Worker,
|
||||
select_next_speaker,
|
||||
select_next_speaker_director,
|
||||
select_speaker_round_table,
|
||||
)
|
||||
|
||||
# Sample agents for testing
|
||||
agent1 = Flow(llm=OpenAIChat(), max_loops=2)
|
||||
agent2 = Flow(llm=OpenAIChat(), max_loops=2)
|
||||
agents = [agent1, agent2]
|
||||
|
||||
def test_multiagentcollaboration_initialization():
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
assert isinstance(multiagentcollaboration, MultiAgentCollaboration)
|
||||
assert len(multiagentcollaboration.agents) == 5
|
||||
assert multiagentcollaboration._step == 0
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.reset")
|
||||
def test_multiagentcollaboration_reset(mock_reset):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
multiagentcollaboration.reset()
|
||||
assert mock_reset.call_count == 5
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.run")
|
||||
def test_multiagentcollaboration_inject(mock_run):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
multiagentcollaboration.inject("Agent 1", "Hello, world!")
|
||||
assert multiagentcollaboration._step == 1
|
||||
assert mock_run.call_count == 5
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.send")
|
||||
@patch("swarms.workers.Worker.receive")
|
||||
def test_multiagentcollaboration_step(mock_receive, mock_send):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
multiagentcollaboration.step()
|
||||
assert multiagentcollaboration._step == 1
|
||||
assert mock_send.call_count == 5
|
||||
assert mock_receive.call_count == 25
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.bid")
|
||||
def test_multiagentcollaboration_ask_for_bid(mock_bid):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
result = multiagentcollaboration.ask_for_bid(Worker)
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.bid")
|
||||
def test_multiagentcollaboration_select_next_speaker(mock_bid):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
result = multiagentcollaboration.select_next_speaker(1, [Worker] * 5)
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.send")
|
||||
@patch("swarms.workers.Worker.receive")
|
||||
def test_multiagentcollaboration_run(mock_receive, mock_send):
|
||||
multiagentcollaboration = MultiAgentCollaboration(
|
||||
agents=[Worker] * 5, selection_function=select_next_speaker
|
||||
)
|
||||
multiagentcollaboration.run(max_iters=5)
|
||||
assert multiagentcollaboration._step == 6
|
||||
assert mock_send.call_count == 30
|
||||
assert mock_receive.call_count == 150
|
||||
|
||||
@pytest.fixture
|
||||
def collaboration():
|
||||
return MultiAgentCollaboration(agents)
|
||||
|
||||
|
||||
def test_collaboration_initialization(collaboration):
|
||||
assert len(collaboration.agents) == 2
|
||||
assert callable(collaboration.select_next_speaker)
|
||||
assert collaboration.max_iters == 10
|
||||
assert collaboration.results == []
|
||||
assert collaboration.logging == True
|
||||
|
||||
|
||||
def test_reset(collaboration):
|
||||
collaboration.reset()
|
||||
for agent in collaboration.agents:
|
||||
assert agent.step == 0
|
||||
|
||||
|
||||
def test_inject(collaboration):
|
||||
collaboration.inject("TestName", "TestMessage")
|
||||
for agent in collaboration.agents:
|
||||
assert "TestName" in agent.history[-1]
|
||||
assert "TestMessage" in agent.history[-1]
|
||||
|
||||
|
||||
def test_inject_agent(collaboration):
|
||||
agent3 = Flow(llm=OpenAIChat(), max_loops=2)
|
||||
collaboration.inject_agent(agent3)
|
||||
assert len(collaboration.agents) == 3
|
||||
assert agent3 in collaboration.agents
|
||||
|
||||
|
||||
def test_step(collaboration):
|
||||
collaboration.step()
|
||||
for agent in collaboration.agents:
|
||||
assert agent.step == 1
|
||||
|
||||
|
||||
def test_ask_for_bid(collaboration):
|
||||
agent = Mock()
|
||||
agent.bid.return_value = "<5>"
|
||||
bid = collaboration.ask_for_bid(agent)
|
||||
assert bid == 5
|
||||
|
||||
|
||||
def test_select_next_speaker(collaboration):
|
||||
collaboration.select_next_speaker = Mock(return_value=0)
|
||||
idx = collaboration.select_next_speaker(1, collaboration.agents)
|
||||
assert idx == 0
|
||||
|
||||
|
||||
def test_run(collaboration):
|
||||
collaboration.run()
|
||||
for agent in collaboration.agents:
|
||||
assert agent.step == collaboration.max_iters
|
||||
|
||||
|
||||
def test_format_results(collaboration):
|
||||
collaboration.results = [{"agent": "Agent1", "response": "Response1"}]
|
||||
formatted_results = collaboration.format_results(collaboration.results)
|
||||
assert "Agent1 responded: Response1" in formatted_results
|
||||
|
||||
|
||||
def test_save_and_load(collaboration):
|
||||
collaboration.save()
|
||||
loaded_state = collaboration.load()
|
||||
assert loaded_state["_step"] == collaboration._step
|
||||
assert loaded_state["results"] == collaboration.results
|
||||
|
||||
|
||||
def test_performance(collaboration):
|
||||
performance_data = collaboration.performance()
|
||||
for agent in collaboration.agents:
|
||||
assert agent.name in performance_data
|
||||
assert "metrics" in performance_data[agent.name]
|
||||
|
||||
|
||||
def test_set_interaction_rules(collaboration):
|
||||
rules = {"rule1": "action1", "rule2": "action2"}
|
||||
collaboration.set_interaction_rules(rules)
|
||||
assert hasattr(collaboration, "interaction_rules")
|
||||
assert collaboration.interaction_rules == rules
|
||||
|
||||
|
||||
def test_set_interaction_rules(collaboration):
|
||||
rules = {"rule1": "action1", "rule2": "action2"}
|
||||
collaboration.set_interaction_rules(rules)
|
||||
assert hasattr(collaboration, "interaction_rules")
|
||||
assert collaboration.interaction_rules == rules
|
||||
|
||||
|
||||
def test_repr(collaboration):
|
||||
repr_str = repr(collaboration)
|
||||
assert isinstance(repr_str, str)
|
||||
assert "MultiAgentCollaboration" in repr_str
|
||||
|
||||
|
||||
def test_load(collaboration):
|
||||
state = {"step": 5, "results": [{"agent": "Agent1", "response": "Response1"}]}
|
||||
with open(collaboration.saved_file_path_name, "w") as file:
|
||||
json.dump(state, file)
|
||||
|
||||
loaded_state = collaboration.load()
|
||||
assert loaded_state["_step"] == state["step"]
|
||||
assert loaded_state["results"] == state["results"]
|
||||
|
||||
|
||||
def test_save(collaboration, tmp_path):
|
||||
collaboration.saved_file_path_name = tmp_path / "test_save.json"
|
||||
collaboration.save()
|
||||
|
||||
with open(collaboration.saved_file_path_name, "r") as file:
|
||||
saved_data = json.load(file)
|
||||
|
||||
assert saved_data["_step"] == collaboration._step
|
||||
assert saved_data["results"] == collaboration.results
|
||||
|
||||
|
||||
# Add more tests here...
|
||||
|
||||
|
||||
# Example of parameterized test for different selection functions
|
||||
@pytest.mark.parametrize(
|
||||
"selection_function", [select_next_speaker_director, select_speaker_round_table]
|
||||
)
|
||||
def test_selection_functions(collaboration, selection_function):
|
||||
collaboration.select_next_speaker = selection_function
|
||||
assert callable(collaboration.select_next_speaker)
|
||||
|
||||
|
||||
# Add more parameterized tests for different scenarios...
|
||||
|
||||
|
||||
# Example of exception testing
|
||||
def test_exception_handling(collaboration):
|
||||
agent = Mock()
|
||||
agent.bid.side_effect = ValueError("Invalid bid")
|
||||
with pytest.raises(ValueError):
|
||||
collaboration.ask_for_bid(agent)
|
||||
|
||||
|
||||
# Add more exception testing...
|
||||
|
||||
|
||||
# Example of environment variable testing (if applicable)
|
||||
@pytest.mark.parametrize("env_var", ["ENV_VAR_1", "ENV_VAR_2"])
|
||||
def test_environment_variables(collaboration, monkeypatch, env_var):
|
||||
monkeypatch.setenv(env_var, "test_value")
|
||||
assert os.getenv(env_var) == "test_value"
|
||||
|
Loading…
Reference in new issue