commit 7383a1eb17

@ -0,0 +1,41 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.7", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install --upgrade swarms
          python -m pip install flake8 pytest
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Test with pytest
        run: |
          find ./tests -name '*.py' -exec pytest {} \;
@ -0,0 +1,21 @@
Developers

Install pre-commit (https://pre-commit.com/)

```bash
pip install pre-commit
```

Check that it's installed

```bash
pre-commit --version
```

This repository already has a pre-commit configuration. To install the hooks, run:

```bash
pre-commit install
```

Now when you make a git commit, the black code formatter and ruff linter will run.
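The configuration itself lives in the repository, but for orientation, a minimal `.pre-commit-config.yaml` wiring up black and ruff looks roughly like this (the `rev` pins are placeholders; check the repository's actual file for the real ones):

```yaml
repos:
  - repo: https://github.com/psf/black
    rev: 23.3.0  # placeholder pin
    hooks:
      - id: black
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.0  # placeholder pin
    hooks:
      - id: ruff
```

Running `pre-commit run --all-files` applies the hooks to the whole tree once, which is useful right after installing them.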
@ -0,0 +1,82 @@
# ElevenLabsText2SpeechTool Documentation

## Table of Contents

1. [Introduction](#introduction)
2. [Class Overview](#class-overview)
   - [Attributes](#attributes)
3. [Installation](#installation)
4. [Usage](#usage)
   - [Initialization](#initialization)
   - [Converting Text to Speech](#converting-text-to-speech)
   - [Playing and Streaming Speech](#playing-and-streaming-speech)
5. [Exception Handling](#exception-handling)
6. [Advanced Usage](#advanced-usage)
7. [Contributing](#contributing)
8. [References](#references)

## 1. Introduction <a name="introduction"></a>

The `ElevenLabsText2SpeechTool` is a Python class that simplifies converting text to speech with the Eleven Labs Text2Speech API. It wraps the API in a convenient interface for generating speech from text and supports multiple languages, making it suitable for a wide range of applications, including voice assistants and audio content generation.

## 2. Class Overview <a name="class-overview"></a>

### Attributes <a name="attributes"></a>

- `model` (Union[ElevenLabsModel, str]): The model to use for text to speech. Defaults to `ElevenLabsModel.MULTI_LINGUAL`.
- `name` (str): The name of the tool. Defaults to `"eleven_labs_text2speech"`.
- `description` (str): A brief description of the tool. Defaults to a detailed explanation of its functionality.

## 3. Installation <a name="installation"></a>

To use the `ElevenLabsText2SpeechTool`, you need to install the required dependencies and have access to the Eleven Labs Text2Speech API. Follow these steps:

1. Install the `elevenlabs` library:

   ```bash
   pip install elevenlabs
   ```

2. Install the `swarms` library:

   ```bash
   pip install swarms
   ```

3. Set up your API key by following the instructions at [Eleven Labs Documentation](https://docs.elevenlabs.io/welcome/introduction).

## 4. Usage <a name="usage"></a>

### Initialization <a name="initialization"></a>

To get started, create an instance of the `ElevenLabsText2SpeechTool`. You can customize the `model` attribute if needed.

```python
from swarms.models import ElevenLabsText2SpeechTool

# `model` accepts an ElevenLabsModel value or a plain string
tts = ElevenLabsText2SpeechTool(model=ElevenLabsModel.MONO_LINGUAL)
```

### Converting Text to Speech <a name="converting-text-to-speech"></a>

Use the `run` method to convert text to speech. It returns the path to the generated speech file.

```python
speech_file = tts.run("Hello, this is a test.")
```

### Playing and Streaming Speech <a name="playing-and-streaming-speech"></a>

- Use the `play` method to play the generated speech file.

```python
tts.play(speech_file)
```

- Use the `stream_speech` method to stream the text as speech. It plays the speech in real time.

```python
tts.stream_speech("Hello world!")
```

## 5. Exception Handling <a name="exception-handling"></a>

The `ElevenLabsText2SpeechTool` handles exceptions gracefully. If an error occurs during the conversion process, it raises a `RuntimeError` with an informative error message.
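Because failures surface as a `RuntimeError`, callers can wrap `run` in a `try`/`except` block. The sketch below uses a hypothetical stand-in function instead of a live API call, so the function name, the failure condition, and the returned path are illustrative only:

```python
from typing import Optional


def text_to_speech(text: str) -> str:
    """Hypothetical stand-in for ElevenLabsText2SpeechTool.run, which
    raises RuntimeError when conversion fails."""
    if not text:
        raise RuntimeError("Text2Speech conversion failed: empty input")
    return "speech.wav"  # the real tool returns the generated file's path


def safe_text_to_speech(text: str) -> Optional[str]:
    # Catch the RuntimeError and degrade gracefully instead of crashing
    try:
        return text_to_speech(text)
    except RuntimeError as err:
        print(f"Conversion failed: {err}")
        return None
```

A similar guard can wrap `play` and `stream_speech` if needed.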

## 6. Advanced Usage <a name="advanced-usage"></a>

- You can implement custom error handling and logging to further enhance the functionality of this tool.
- Advanced users can extend the class to support additional features or customization.

## 7. Contributing <a name="contributing"></a>

Contributions to this tool are welcome. Feel free to open issues, submit pull requests, or provide feedback to improve its functionality and documentation.

## 8. References <a name="references"></a>

- [Eleven Labs Text2Speech API Documentation](https://docs.elevenlabs.io/welcome/introduction)

This documentation provides a comprehensive guide to using the `ElevenLabsText2SpeechTool`. It covers installation, basic usage, advanced features, and contribution guidelines. Refer to the [References](#references) section for additional resources.
@ -0,0 +1,81 @@
import os

from dotenv import load_dotenv

from swarms.models import Anthropic, OpenAIChat
from swarms.prompts.accountant_swarm_prompts import (
    DECISION_MAKING_PROMPT,
    DOC_ANALYZER_AGENT_PROMPT,
    SUMMARY_GENERATOR_AGENT_PROMPT,
)
from swarms.structs import Flow
from swarms.utils.pdf_to_text import pdf_to_text

# Environment variables
load_dotenv()
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")


# Base llms
llm1 = OpenAIChat(
    openai_api_key=openai_api_key,
    max_tokens=5000,
)

llm2 = Anthropic(
    anthropic_api_key=anthropic_api_key,
    max_tokens=5000,
)


# Agents
doc_analyzer_agent = Flow(
    llm=llm2,
    sop=DOC_ANALYZER_AGENT_PROMPT,
    max_loops=1,
    autosave=True,
    saved_state_path="doc_analyzer_agent.json",
)
summary_generator_agent = Flow(
    llm=llm2,
    sop=SUMMARY_GENERATOR_AGENT_PROMPT,
    max_loops=1,
    autosave=True,
    saved_state_path="summary_generator_agent.json",
)
decision_making_support_agent = Flow(
    llm=llm2,
    sop=DECISION_MAKING_PROMPT,
    max_loops=1,
    saved_state_path="decision_making_support_agent.json",
)


pdf_path = "bankstatement.pdf"
fraud_detection_instructions = "Detect fraud in the document"
summary_agent_instructions = (
    "Generate an actionable summary of the document with action steps to take"
)
decision_making_support_agent_instructions = (
    "Provide decision making support to the business owner:"
)


# Transform the pdf to text
pdf_text = pdf_to_text(pdf_path)
print(pdf_text)


# Detect fraud in the document
fraud_detection_agent_output = doc_analyzer_agent.run(
    f"{fraud_detection_instructions}: {pdf_text}"
)

# Generate an actionable summary of the document
summary_agent_output = summary_generator_agent.run(
    f"{summary_agent_instructions}: {fraud_detection_agent_output}"
)

# Provide decision making support to the accountant
decision_making_support_agent_output = decision_making_support_agent.run(
    f"{decision_making_support_agent_instructions}: {summary_agent_output}"
)
@ -1,35 +1,117 @@
import os
from typing import List

from dotenv import load_dotenv

from swarms.models import Anthropic, OpenAIChat
from swarms.prompts.accountant_swarm_prompts import (
    DECISION_MAKING_PROMPT,
    DOC_ANALYZER_AGENT_PROMPT,
    FRAUD_DETECTION_AGENT_PROMPT,
    SUMMARY_GENERATOR_AGENT_PROMPT,
)
from swarms.structs import Flow
from swarms.utils.pdf_to_text import pdf_to_text

# Environment variables
load_dotenv()
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

# Base llms
llm1 = OpenAIChat(
    openai_api_key=openai_api_key,
)

llm2 = Anthropic(
    anthropic_api_key=anthropic_api_key,
)

# Agents
doc_analyzer_agent = Flow(
    llm=llm1,
    sop=DOC_ANALYZER_AGENT_PROMPT,
)
summary_generator_agent = Flow(
    llm=llm2,
    sop=SUMMARY_GENERATOR_AGENT_PROMPT,
)
decision_making_support_agent = Flow(
    llm=llm2,
    sop=DECISION_MAKING_PROMPT,
)


class AccountantSwarms:
    """
    Accountant Swarms is a collection of agents that work together to help
    accountants with their work.

    Flow: analyze doc -> detect fraud -> generate summary -> decision making support

    The agents are:
    - User Consultant: Asks the user many questions
    - Document Analyzer: Extracts text from the image of the financial document
    - Fraud Detection: Detects fraud in the document
    - Summary Agent: Generates an actionable summary of the document
    - Decision Making Support: Provides decision making support to the accountant

    The agents are connected together in a workflow that is defined in the
    run method.

    The workflow is as follows:
    1. The Document Analyzer agent extracts text from the image of the
       financial document.
    2. The Fraud Detection agent detects fraud in the document.
    3. The Summary Agent generates an actionable summary of the document.
    4. The Decision Making Support agent provides decision making support.
    """

    def __init__(
        self,
        pdf_path: str,
        list_pdfs: List[str] = None,
        fraud_detection_instructions: str = None,
        summary_agent_instructions: str = None,
        decision_making_support_agent_instructions: str = None,
    ):
        super().__init__()
        self.pdf_path = pdf_path
        self.list_pdfs = list_pdfs
        self.fraud_detection_instructions = fraud_detection_instructions
        self.summary_agent_instructions = summary_agent_instructions
        self.decision_making_support_agent_instructions = (
            decision_making_support_agent_instructions
        )

    def run(self):
        # Transform the pdf to text
        pdf_text = pdf_to_text(self.pdf_path)

        # Detect fraud in the document
        fraud_detection_agent_output = doc_analyzer_agent.run(
            f"{self.fraud_detection_instructions}: {pdf_text}"
        )

        # Generate an actionable summary of the document
        summary_agent_output = summary_generator_agent.run(
            f"{self.summary_agent_instructions}: {fraud_detection_agent_output}"
        )

        # Provide decision making support to the accountant
        decision_making_support_agent_output = decision_making_support_agent.run(
            f"{self.decision_making_support_agent_instructions}: {summary_agent_output}"
        )

        return decision_making_support_agent_output


swarm = AccountantSwarms(
    pdf_path="tesla.pdf",
    fraud_detection_instructions="Detect fraud in the document",
    summary_agent_instructions="Generate an actionable summary of the document",
    decision_making_support_agent_instructions="Provide decision making support to the business owner:",
)
@ -0,0 +1,53 @@
import os

from dotenv import load_dotenv

from swarms.models import Anthropic, OpenAIChat
from swarms.prompts.ai_research_team import (
    PAPER_IMPLEMENTOR_AGENT_PROMPT,
    PAPER_SUMMARY_ANALYZER,
)
from swarms.structs import Flow
from swarms.utils.pdf_to_text import pdf_to_text

# Environment variables
load_dotenv()
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

PDF_PATH = "fasterffn.pdf"


# Base llms
llm1 = OpenAIChat(
    openai_api_key=openai_api_key,
)

llm2 = Anthropic(
    anthropic_api_key=anthropic_api_key,
)

# Agents
paper_summarizer_agent = Flow(
    llm=llm2,
    sop=PAPER_SUMMARY_ANALYZER,
    max_loops=1,
    autosave=True,
    saved_state_path="paper_summarizer.json",
)

paper_implementor_agent = Flow(
    llm=llm1,
    sop=PAPER_IMPLEMENTOR_AGENT_PROMPT,
    max_loops=1,
    autosave=True,
    saved_state_path="paper_implementor.json",
    code_interpreter=False,
)

paper = pdf_to_text(PDF_PATH)
algorithmic_pseudocode_agent = paper_summarizer_agent.run(
    f"Focus on creating the algorithmic pseudocode for the novel method in this paper: {paper}"
)
pytorch_code = paper_implementor_agent.run(algorithmic_pseudocode_agent)
Binary file not shown.
@ -0,0 +1,10 @@
from swarms import Flow, Fuyu

llm = Fuyu()

flow = Flow(max_loops="auto", llm=llm)

flow.run(
    task="Describe this image in a few sentences: ",
    img="https://unsplash.com/photos/0pIC5ByPpZY",
)
@ -0,0 +1,14 @@
# This might not work in the beginning but it's a starting point
from swarms.structs import Flow, GPT4V

llm = GPT4V()

flow = Flow(
    max_loops="auto",
    llm=llm,
)

flow.run(
    task="Describe this image in a few sentences: ",
    img="https://unsplash.com/photos/0pIC5ByPpZY",
)
@ -1,15 +0,0 @@
from swarms import WorkerUltraUltraNode

# Define an objective
objective = """
Please make a web GUI for using HTTP API server.
The name of it is Swarms.
You can check the server code at ./main.py.
The server is served on localhost:8000.
Users should be able to write text input as 'query' and url array as 'files', and check the response.
Users input form should be delivered in JSON format.
I want it to have neumorphism-style. Serve it on port 4500.
"""

node = WorkerUltraUltraNode(objective)
result = node.execute()
@ -1,17 +0,0 @@
from langchain.models import OpenAIChat
from swarms import Worker

llm = OpenAIChat(model_name="gpt-4", openai_api_key="api-key", temperature=0.5)

node = Worker(
    llm=llm,
    ai_name="Optimus Prime",
    ai_role="Worker in a swarm",
    external_tools=None,
    human_in_the_loop=False,
    temperature=0.5,
)

task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times."
response = node.run(task)
print(response)
@ -1,15 +0,0 @@
from swarms import worker_node

# Your OpenAI API key
api_key = "sksdsds"

# Initialize a WorkerNode with your API key
node = worker_node(api_key)

# Define an objective
objective = "Please make a web GUI for using HTTP API server..."

# Run the task
task = node.run(objective)

print(task)
@ -1,25 +0,0 @@
import os

from swarms.swarms.swarms import WorkerUltra

api_key = os.getenv("OPENAI_API_KEY")

# Define an objective
objective = """
Please make a web GUI for using HTTP API server.
The name of it is Swarms.
You can check the server code at ./main.py.
The server is served on localhost:8000.
Users should be able to write text input as 'query' and url array as 'files', and check the response.
Users input form should be delivered in JSON format.
I want it to have neumorphism-style. Serve it on port 4500.
"""

# Create an instance of WorkerUltra
worker = WorkerUltra(objective, api_key)

# Execute the task
result = worker.execute()

# Print the result
print(result)
@ -1,517 +0,0 @@
import json
import os
import shutil
import time
import xml.etree.ElementTree as ET
import zipfile
from tempfile import mkdtemp
from typing import Dict, Optional
from urllib.parse import urlparse

import pyautogui
import requests
import semver
import undetected_chromedriver as uc  # type: ignore
import yaml
from extension import load_extension
from pydantic import BaseModel
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from tqdm import tqdm


def _is_blank_agent(agent_name: str) -> bool:
    with open(f"agents/{agent_name}.py", "r") as agent_file:
        agent_data = agent_file.read()
    with open("src/template.py", "r") as template_file:
        template_data = template_file.read()
    return agent_data == template_data


def record(agent_name: str, autotab_ext_path: Optional[str] = None):
    if not os.path.exists("agents"):
        os.makedirs("agents")

    if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local":
        if not _is_blank_agent(agent_name=agent_name):
            raise Exception(f"Agent with name {agent_name} already exists")
    driver = get_driver(  # noqa: F841
        autotab_ext_path=autotab_ext_path,
        record_mode=True,
    )
    # Need to keep a reference to the driver so that it doesn't get garbage collected
    with open("src/template.py", "r") as file:
        data = file.read()

    with open(f"agents/{agent_name}.py", "w") as file:
        file.write(data)

    print(
        "\033[34mYou have the Python debugger open, you can run commands in it like you"
        " would in a normal Python shell.\033[0m"
    )
    print(
        "\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and"
        " press enter.\033[0m"
    )
    breakpoint()


if __name__ == "__main__":
    record("agent")


def extract_domain_from_url(url: str):
    # url = http://username:password@hostname:port/path?arg=value#anchor
    parsed_url = urlparse(url)
    hostname = parsed_url.hostname
    if hostname is None:
        raise ValueError(f"Could not extract hostname from url {url}")
    if hostname.startswith("www."):
        hostname = hostname[4:]
    return hostname


class AutotabChromeDriver(uc.Chrome):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def find_element_with_retry(
        self, by=By.ID, value: Optional[str] = None
    ) -> WebElement:
        try:
            return super().find_element(by, value)
        except Exception as e:
            # TODO: Use an LLM to retry, finding a similar element on the DOM
            breakpoint()
            raise e


def open_plugin(driver: AutotabChromeDriver):
    print("Opening plugin sidepanel")
    driver.execute_script("document.activeElement.blur();")
    pyautogui.press("esc")
    pyautogui.hotkey("command", "shift", "y", interval=0.05)  # mypy: ignore


def open_plugin_and_login(driver: AutotabChromeDriver):
    if config.autotab_api_key is not None:
        backend_url = (
            "http://localhost:8000"
            if config.environment == "local"
            else "https://api.autotab.com"
        )
        driver.get(f"{backend_url}/auth/signin-api-key-page")
        response = requests.post(
            f"{backend_url}/auth/signin-api-key",
            json={"api_key": config.autotab_api_key},
        )
        cookie = response.json()
        if response.status_code != 200:
            if response.status_code == 401:
                raise Exception("Invalid API key")
            else:
                raise Exception(
                    f"Error {response.status_code} from backend while logging you in"
                    f" with your API key: {response.text}"
                )
        cookie["name"] = cookie["key"]
        del cookie["key"]
        driver.add_cookie(cookie)

        driver.get("https://www.google.com")
        open_plugin(driver)
    else:
        print("No autotab API key found, heading to autotab.com to sign up")

        url = (
            "http://localhost:3000/dashboard"
            if config.environment == "local"
            else "https://autotab.com/dashboard"
        )
        driver.get(url)
        time.sleep(0.5)

        open_plugin(driver)


def get_driver(
    autotab_ext_path: Optional[str] = None, record_mode: bool = False
) -> AutotabChromeDriver:
    options = webdriver.ChromeOptions()
    options.add_argument("--no-sandbox")  # Necessary for running
    options.add_argument(
        "--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
    )
    options.add_argument("--enable-webgl")
    options.add_argument("--enable-3d-apis")
    options.add_argument("--enable-clipboard-read-write")
    options.add_argument("--disable-popup-blocking")

    if autotab_ext_path is None:
        load_extension()
        options.add_argument("--load-extension=./src/extension/autotab")
    else:
        options.add_argument(f"--load-extension={autotab_ext_path}")

    options.add_argument("--allow-running-insecure-content")
    options.add_argument("--disable-web-security")
    options.add_argument(f"--user-data-dir={mkdtemp()}")
    options.binary_location = config.chrome_binary_location
    driver = AutotabChromeDriver(options=options)
    if record_mode:
        open_plugin_and_login(driver)

    return driver


class SiteCredentials(BaseModel):
    name: Optional[str] = None
    email: Optional[str] = None
    password: Optional[str] = None
    login_with_google_account: Optional[str] = None
    login_url: Optional[str] = None

    def __init__(self, **data) -> None:
        super().__init__(**data)
        if self.name is None:
            self.name = self.email


class GoogleCredentials(BaseModel):
    credentials: Dict[str, SiteCredentials]

    def __init__(self, **data) -> None:
        super().__init__(**data)
        for cred in self.credentials.values():
            cred.login_url = "https://accounts.google.com/v3/signin"

    @property
    def default(self) -> SiteCredentials:
        if "default" not in self.credentials:
            if len(self.credentials) == 1:
                return list(self.credentials.values())[0]
            raise Exception("No default credentials found in config")
        return self.credentials["default"]


class Config(BaseModel):
    autotab_api_key: Optional[str]
    credentials: Dict[str, SiteCredentials]
    google_credentials: GoogleCredentials
    chrome_binary_location: str
    environment: str

    @classmethod
    def load_from_yaml(cls, path: str):
        with open(path, "r") as config_file:
            config = yaml.safe_load(config_file)
            _credentials = {}
            for domain, creds in config.get("credentials", {}).items():
                if "login_url" not in creds:
                    creds["login_url"] = f"https://{domain}/login"
                site_creds = SiteCredentials(**creds)
                _credentials[domain] = site_creds
                for alt in creds.get("alts", []):
                    _credentials[alt] = site_creds

            google_credentials = {}
            for creds in config.get("google_credentials", []):
                credentials: SiteCredentials = SiteCredentials(**creds)
                google_credentials[credentials.name] = credentials

            chrome_binary_location = config.get("chrome_binary_location")
            if chrome_binary_location is None:
                raise Exception("Must specify chrome_binary_location in config")

            autotab_api_key = config.get("autotab_api_key")
            if autotab_api_key == "...":
                autotab_api_key = None

            return cls(
                autotab_api_key=autotab_api_key,
                credentials=_credentials,
                google_credentials=GoogleCredentials(credentials=google_credentials),
                chrome_binary_location=config.get("chrome_binary_location"),
                environment=config.get("environment", "prod"),
            )

    def get_site_credentials(self, domain: str) -> SiteCredentials:
        credentials = self.credentials[domain].copy()
        return credentials


config = Config.load_from_yaml(".autotab.yaml")


def is_signed_in_to_google(driver):
    cookies = driver.get_cookies()
    return len([c for c in cookies if c["name"] == "SAPISID"]) != 0


def google_login(
    driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True
):
    print("Logging in to Google")
    if navigate:
        driver.get("https://accounts.google.com/")
        time.sleep(1)
        if is_signed_in_to_google(driver):
            print("Already signed in to Google")
            return

    if os.path.exists("google_cookies.json"):
        print("cookies exist, doing loading")
        with open("google_cookies.json", "r") as f:
            google_cookies = json.load(f)
            for cookie in google_cookies:
                if "expiry" in cookie:
                    cookie["expires"] = cookie["expiry"]
                    del cookie["expiry"]
                driver.execute_cdp_cmd("Network.setCookie", cookie)
            time.sleep(1)
            driver.refresh()
            time.sleep(2)

    if not credentials:
        credentials = config.google_credentials.default

    if credentials is None:
        raise Exception("No credentials provided for Google login")

    email_input = driver.find_element(By.CSS_SELECTOR, "[type='email']")
    email_input.send_keys(credentials.email)
    email_input.send_keys(Keys.ENTER)
    WebDriverWait(driver, 10).until(
        EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))
    )

    password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
    password_input.send_keys(credentials.password)
    password_input.send_keys(Keys.ENTER)
    time.sleep(1.5)
    print("Successfully logged in to Google")

    cookies = driver.get_cookies()
    if not is_signed_in_to_google(driver):
        # Probably wanted to have us solve a captcha, or 2FA or confirm recovery details
        print("Need 2FA help to log in to Google")
        # TODO: Show screenshot it to the user
        breakpoint()

    if not os.path.exists("google_cookies.json"):
        print("Setting Google cookies for future use")
        # Log out to have access to the right cookies
        driver.get("https://accounts.google.com/Logout")
        time.sleep(2)
        cookies = driver.get_cookies()
        cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"]
        google_cookies = [
            cookie
            for cookie in cookies
            if cookie["domain"] in [".google.com", "accounts.google.com"]
            and cookie["name"] in cookie_names
        ]
        with open("google_cookies.json", "w") as f:
            json.dump(google_cookies, f)

        # Log back in
        login_button = driver.find_element(
            By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']"
        )
        login_button.click()
        time.sleep(1)
        password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
        password_input.send_keys(credentials.password)
        password_input.send_keys(Keys.ENTER)

        time.sleep(3)
        print("Successfully copied Google cookies for the future")


def login(driver, url: str):
    domain = extract_domain_from_url(url)

    credentials = config.get_site_credentials(domain)
    login_url = credentials.login_url
    if credentials.login_with_google_account:
        google_credentials = config.google_credentials.credentials[
            credentials.login_with_google_account
        ]
        _login_with_google(driver, login_url, google_credentials)
    else:
        _login(driver, login_url, credentials=credentials)


def _login(driver, url: str, credentials: SiteCredentials):
    print(f"Logging in to {url}")
    driver.get(url)
    time.sleep(2)
    email_input = driver.find_element(By.NAME, "email")
    email_input.send_keys(credentials.email)
    password_input = driver.find_element(By.NAME, "password")
    password_input.send_keys(credentials.password)
    password_input.send_keys(Keys.ENTER)

    time.sleep(3)
    print(f"Successfully logged in to {url}")


def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
    print(f"Logging in to {url} with Google")

    google_login(driver, credentials=google_credentials)

    driver.get(url)
    WebDriverWait(driver, 10).until(
        EC.presence_of_element_located((By.TAG_NAME, "body"))
    )

    main_window = driver.current_window_handle
    xpath = (
        "//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with"
        " Google') or contains(@title, 'Sign in with Google')]"
    )

    WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath)))
    driver.find_element(
        By.XPATH,
        xpath,
    ).click()

    driver.switch_to.window(driver.window_handles[-1])
    driver.find_element(
        By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]"
    ).click()

    driver.switch_to.window(main_window)

    time.sleep(5)
    print(f"Successfully logged in to {url}")


def update():
    print("updating extension...")
    # Download the autotab.crx file
    response = requests.get(
        "https://github.com/Planetary-Computers/autotab-extension/raw/main/autotab.crx",
        stream=True,
    )

    # Check if the directory exists, if not create it
    if os.path.exists("src/extension/.autotab"):
        shutil.rmtree("src/extension/.autotab")
    os.makedirs("src/extension/.autotab")

    # Open the file in write binary mode
    total_size = int(response.headers.get("content-length", 0))
    block_size = 1024  # 1 Kibibyte
    t = tqdm(total=total_size, unit="iB", unit_scale=True)
    with open("src/extension/.autotab/autotab.crx", "wb") as f:
        for data in response.iter_content(block_size):
            t.update(len(data))
            f.write(data)
    t.close()
    if total_size != 0 and t.n != total_size:
        print("ERROR, something went wrong")

    # Unzip the file
    with zipfile.ZipFile("src/extension/.autotab/autotab.crx", "r") as zip_ref:
|
||||
zip_ref.extractall("src/extension/.autotab")
|
||||
os.remove("src/extension/.autotab/autotab.crx")
|
||||
if os.path.exists("src/extension/autotab"):
|
||||
shutil.rmtree("src/extension/autotab")
|
||||
os.rename("src/extension/.autotab", "src/extension/autotab")
|
||||
|
||||
|
||||
def should_update():
|
||||
if not os.path.exists("src/extension/autotab"):
|
||||
return True
|
||||
# Fetch the XML file
|
||||
response = requests.get(
|
||||
"https://raw.githubusercontent.com/Planetary-Computers/autotab-extension/main/update.xml"
|
||||
)
|
||||
xml_content = response.content
|
||||
|
||||
# Parse the XML file
|
||||
root = ET.fromstring(xml_content)
|
||||
namespaces = {"ns": "http://www.google.com/update2/response"} # add namespaces
|
||||
xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version")
|
||||
|
||||
# Load the local JSON file
|
||||
with open("src/extension/autotab/manifest.json", "r") as f:
|
||||
json_content = json.load(f)
|
||||
json_version = json_content["version"]
|
||||
# Compare versions
|
||||
return semver.compare(xml_version, json_version) > 0
|
||||
|
||||
|
||||
def load_extension():
|
||||
should_update() and update()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("should update:", should_update())
|
||||
update()
|
||||
|
||||
|
||||
def play(agent_name: Optional[str] = None):
|
||||
if agent_name is None:
|
||||
agent_files = os.listdir("agents")
|
||||
if len(agent_files) == 0:
|
||||
raise Exception("No agents found in agents/ directory")
|
||||
elif len(agent_files) == 1:
|
||||
agent_file = agent_files[0]
|
||||
else:
|
||||
print("Found multiple agent files, please select one:")
|
||||
for i, file in enumerate(agent_files, start=1):
|
||||
print(f"{i}. {file}")
|
||||
|
||||
selected = int(input("Select a file by number: ")) - 1
|
||||
agent_file = agent_files[selected]
|
||||
else:
|
||||
agent_file = f"{agent_name}.py"
|
||||
|
||||
os.system(f"python agents/{agent_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
play()
|
||||
"""
|
||||
|
||||
|
||||
chrome_binary_location: /Applications/Google Chrome.app/Contents/MacOS/Google Chrome
|
||||
|
||||
autotab_api_key: ... # Go to https://autotab.com/dashboard to get your API key, or
|
||||
# run `autotab record` with this field blank and you will be prompted to log in to autotab
|
||||
|
||||
# Optional, programmatically login to services using "Login with Google" authentication
|
||||
google_credentials:
|
||||
- name: default
|
||||
email: ...
|
||||
password: ...
|
||||
|
||||
# Optional, specify alternative accounts to use with Google login on a per-service basis
|
||||
- email: you@gmail.com # Credentials without a name use email as key
|
||||
password: ...
|
||||
|
||||
credentials:
|
||||
notion.so:
|
||||
alts:
|
||||
- notion.com
|
||||
login_with_google_account: default
|
||||
|
||||
figma.com:
|
||||
email: ...
|
||||
password: ...
|
||||
|
||||
airtable.com:
|
||||
login_with_google_account: you@gmail.com
|
||||
"""
|
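The update gate in `should_update` hinges on `semver.compare` returning a positive value when the remote version is newer than the local manifest version. A minimal stdlib-only stand-in shows the same ordering (the helper name is illustrative and handles plain dotted numeric versions only, not semver prerelease tags):

```python
def compare_versions(a: str, b: str) -> int:
    """Compare dotted numeric version strings; returns -1, 0, or 1 like semver.compare."""
    pa = [int(x) for x in a.split(".")]
    pb = [int(x) for x in b.split(".")]
    # Python compares lists lexicographically, element by element.
    return (pa > pb) - (pa < pb)

# Update only when the remote (update.xml) version is strictly newer than the local one.
print(compare_versions("0.2.0", "0.1.9") > 0)  # True -> download the new extension
print(compare_versions("0.1.0", "0.1.0") > 0)  # False -> already up to date
```

For real version strings that may carry prerelease or build metadata, the `semver` library used above is the safer choice.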
@ -1,81 +0,0 @@
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod
from attr import define, field, Factory
from marshmallow import class_registry
from marshmallow.exceptions import RegistryError


@define
class BaseArtifact(ABC):
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
    name: str = field(
        default=Factory(lambda self: self.id, takes_self=True), kw_only=True
    )
    value: any = field()
    type: str = field(
        default=Factory(lambda self: self.__class__.__name__, takes_self=True),
        kw_only=True,
    )

    @classmethod
    def value_to_bytes(cls, value: any) -> bytes:
        if isinstance(value, bytes):
            return value
        else:
            return str(value).encode()

    @classmethod
    def value_to_dict(cls, value: any) -> dict:
        if isinstance(value, dict):
            dict_value = value
        else:
            dict_value = json.loads(value)

        return {k: v for k, v in dict_value.items()}

    @classmethod
    def from_dict(cls, artifact_dict: dict) -> BaseArtifact:
        from griptape.schemas import (
            TextArtifactSchema,
            InfoArtifactSchema,
            ErrorArtifactSchema,
            BlobArtifactSchema,
            CsvRowArtifactSchema,
            ListArtifactSchema,
        )

        class_registry.register("TextArtifact", TextArtifactSchema)
        class_registry.register("InfoArtifact", InfoArtifactSchema)
        class_registry.register("ErrorArtifact", ErrorArtifactSchema)
        class_registry.register("BlobArtifact", BlobArtifactSchema)
        class_registry.register("CsvRowArtifact", CsvRowArtifactSchema)
        class_registry.register("ListArtifact", ListArtifactSchema)

        try:
            return class_registry.get_class(artifact_dict["type"])().load(artifact_dict)
        except RegistryError:
            raise ValueError("Unsupported artifact type")

    @classmethod
    def from_json(cls, artifact_str: str) -> BaseArtifact:
        return cls.from_dict(json.loads(artifact_str))

    def __str__(self):
        return json.dumps(self.to_dict())

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @abstractmethod
    def to_text(self) -> str:
        ...

    @abstractmethod
    def to_dict(self) -> dict:
        ...

    @abstractmethod
    def __add__(self, other: BaseArtifact) -> BaseArtifact:
        ...
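The `from_dict` dispatch above looks a concrete class up by its `"type"` field in a registry and raises `ValueError` for unknown types. The same pattern can be sketched with a plain dict standing in for marshmallow's `class_registry` (all names below are illustrative, not part of the file above):

```python
_REGISTRY: dict = {}

def register(name: str, cls) -> None:
    """Associate a type name with a concrete class."""
    _REGISTRY[name] = cls

class TextArtifactSketch:
    def __init__(self, value: str):
        self.value = value

register("TextArtifact", TextArtifactSketch)

def from_dict(artifact_dict: dict):
    # Dispatch on the "type" field, as BaseArtifact.from_dict does.
    try:
        cls = _REGISTRY[artifact_dict["type"]]
    except KeyError:
        raise ValueError("Unsupported artifact type")
    return cls(artifact_dict["value"])

obj = from_dict({"type": "TextArtifact", "value": "hello"})
print(obj.value)  # hello
```

Registering schemas inside `from_dict` (as the file above does) keeps the import lazy at the cost of re-registering on every call; a module-level registry like this sketch avoids the repeated work.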
@ -1,19 +0,0 @@
from __future__ import annotations
from attr import define, field
from swarms.artifacts.base import BaseArtifact


@define(frozen=True)
class ErrorArtifact(BaseArtifact):
    value: str = field(converter=str)

    def __add__(self, other: ErrorArtifact) -> ErrorArtifact:
        return ErrorArtifact(self.value + other.value)

    def to_text(self) -> str:
        return self.value

    def to_dict(self) -> dict:
        from griptape.schemas import ErrorArtifactSchema

        return dict(ErrorArtifactSchema().dump(self))
@ -1,74 +0,0 @@
from __future__ import annotations
import pprint
import json

from typing import Optional
from pydantic import BaseModel, Field, StrictStr


class Artifact(BaseModel):
    """
    Artifact that the task has produced.

    Attributes:
    -----------
    artifact_id: str
        ID of the artifact
    file_name: str
        Filename of the artifact
    relative_path: str
        Relative path of the artifact
    """

    artifact_id: StrictStr = Field(..., description="ID of the artifact")
    file_name: StrictStr = Field(..., description="Filename of the artifact")
    relative_path: Optional[StrictStr] = Field(
        None, description="Relative path of the artifact"
    )
    __properties = ["artifact_id", "file_name", "relative_path"]

    class Config:
        """Pydantic configuration"""

        allow_population_by_field_name = True
        validate_assignment = True

    def to_str(self) -> str:
        """Returns the string representation of the model using alias"""
        return pprint.pformat(self.dict(by_alias=True))

    @classmethod
    def from_json(cls, json_str: str) -> Artifact:
        """Create an instance of Artifact from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self):
        """Returns the dict representation of the model"""
        _dict = self.dict(by_alias=True, exclude={}, exclude_none=True)
        return _dict

    @classmethod
    def from_dict(cls, obj: dict) -> Artifact:
        """Create an instance of Artifact from a dict"""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return Artifact.parse_obj(obj)

        _obj = Artifact.parse_obj(
            {
                "artifact_id": obj.get("artifact_id"),
                "file_name": obj.get("file_name"),
                "relative_path": obj.get("relative_path"),
            }
        )

        return _obj
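The `from_dict`/`to_dict` pair above is a plain serialization roundtrip with `exclude_none=True` dropping unset fields. With the same field names, a stdlib `dataclasses` sketch behaves equivalently for well-formed input (the class below is illustrative, not the pydantic model):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class ArtifactSketch:
    artifact_id: str
    file_name: str
    relative_path: Optional[str] = None

    @classmethod
    def from_dict(cls, obj: dict) -> "ArtifactSketch":
        return cls(
            artifact_id=obj.get("artifact_id"),
            file_name=obj.get("file_name"),
            relative_path=obj.get("relative_path"),
        )

    def to_dict(self) -> dict:
        # Mirror exclude_none=True: drop fields whose value is None.
        return {k: v for k, v in asdict(self).items() if v is not None}

a = ArtifactSketch.from_dict({"artifact_id": "a1", "file_name": "report.csv"})
print(a.to_dict())  # {'artifact_id': 'a1', 'file_name': 'report.csv'}
```

Unlike the pydantic model, this sketch performs no type validation; it only demonstrates the roundtrip shape.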
@ -0,0 +1,6 @@
"""
QDRANT MEMORY CLASS
"""
@ -0,0 +1,104 @@
import tempfile
from enum import Enum
from typing import Any, Dict, Union

from langchain.utils import get_from_dict_or_env
from pydantic import root_validator

from swarms.tools.tool import BaseTool


def _import_elevenlabs() -> Any:
    try:
        import elevenlabs
    except ImportError as e:
        raise ImportError(
            "Cannot import elevenlabs; install it with `pip install elevenlabs`."
        ) from e
    return elevenlabs


class ElevenLabsModel(str, Enum):
    """Models available for Eleven Labs Text2Speech."""

    MULTI_LINGUAL = "eleven_multilingual_v1"
    MONO_LINGUAL = "eleven_monolingual_v1"


class ElevenLabsText2SpeechTool(BaseTool):
    """Tool that queries the Eleven Labs Text2Speech API.

    In order to set this up, follow the instructions at:
    https://docs.elevenlabs.io/welcome/introduction

    Attributes:
        model (ElevenLabsModel): The model to use for text to speech.
            Defaults to ElevenLabsModel.MULTI_LINGUAL.
        name (str): The name of the tool. Defaults to "eleven_labs_text2speech".
        description (str): The description of the tool.
            Defaults to "A wrapper around Eleven Labs Text2Speech. Useful for when you need to convert text to speech. It supports multiple languages, including English, German, Polish, Spanish, Italian, French, Portuguese, and Hindi."

    Usage:
        >>> from swarms.models import ElevenLabsText2SpeechTool
        >>> stt = ElevenLabsText2SpeechTool()
        >>> speech_file = stt.run("Hello world!")
        >>> stt.play(speech_file)
        >>> stt.stream_speech("Hello world!")
    """

    model: Union[ElevenLabsModel, str] = ElevenLabsModel.MULTI_LINGUAL

    name: str = "eleven_labs_text2speech"
    description: str = (
        "A wrapper around Eleven Labs Text2Speech. "
        "Useful for when you need to convert text to speech. "
        "It supports multiple languages, including English, German, Polish, "
        "Spanish, Italian, French, Portuguese, and Hindi. "
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        _ = get_from_dict_or_env(values, "eleven_api_key", "ELEVEN_API_KEY")

        return values

    def _run(
        self,
        task: str,
    ) -> str:
        """Use the tool."""
        elevenlabs = _import_elevenlabs()
        try:
            speech = elevenlabs.generate(text=task, model=self.model)
            with tempfile.NamedTemporaryFile(
                mode="bx", suffix=".wav", delete=False
            ) as f:
                f.write(speech)
            return f.name
        except Exception as e:
            raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}") from e

    def play(self, speech_file: str) -> None:
        """Play the text as speech."""
        elevenlabs = _import_elevenlabs()
        with open(speech_file, mode="rb") as f:
            speech = f.read()

        elevenlabs.play(speech)

    def stream_speech(self, query: str) -> None:
        """Stream the text as speech as it is generated.

        Play the text in your speakers."""
        elevenlabs = _import_elevenlabs()
        speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
        elevenlabs.stream(speech_stream)

    def save(self, speech_file: str, path: str) -> None:
        """Save the speech file to a path."""
        raise NotImplementedError("Saving is not implemented for this tool.")

    def __str__(self):
        return "ElevenLabsText2SpeechTool"
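The `_import_elevenlabs` helper above defers the import until the tool actually runs, so the package stays an optional dependency and users get an install hint instead of a bare `ImportError`. The pattern generalizes; the function below is a generic sketch, not part of the tool:

```python
import importlib
from typing import Optional

def import_optional(name: str, pip_name: Optional[str] = None):
    """Lazily import an optional dependency, raising a helpful error if it is missing."""
    try:
        return importlib.import_module(name)
    except ImportError as e:
        raise ImportError(
            f"Cannot import {name}; install it with `pip install {pip_name or name}`."
        ) from e

# A stdlib module imports fine; a missing one raises with an install hint.
json_mod = import_optional("json")
print(json_mod.dumps({"ok": True}))  # {"ok": true}
```

Chaining with `from e` preserves the original traceback, which matters when the import fails for a reason other than the package being absent.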
@ -0,0 +1,90 @@
ONBOARDING_AGENT_PROMPT = """
Onboarding:

"As the Onboarding Agent, your role is critical in guiding new users, particularly tech-savvy entrepreneurs, through the initial stages of engaging with our advanced swarm technology services. Begin by welcoming users in a friendly, professional manner, setting a positive tone for the interaction. Your conversation should flow logically, starting with an introduction to our services and their potential benefits for the user's specific business context.

Inquire about their industry, delving into specifics such as the industry's current trends, challenges, and the role technology plays in their sector. Show expertise and understanding by using industry-specific terminology and referencing relevant technological advancements. Ask open-ended questions to encourage detailed responses, enabling you to gain a comprehensive understanding of their business needs and objectives.

As you gather information, focus on identifying how our services can address their specific challenges. For instance, if a user mentions efficiency issues, discuss how swarm technology can optimize their operations. Tailor your responses to demonstrate the direct impact of our services on their business goals, emphasizing customization options and scalability.

Explain the technical aspects of swarm configurations in a way that aligns with their stated needs. Use analogies or real-world examples to simplify complex concepts. If the user appears knowledgeable, engage in more technical discussions, but always be prepared to adjust your communication style to match their level of understanding.

Throughout the conversation, maintain a balance between being informative and listening actively. Validate their concerns and provide reassurances where necessary, especially regarding data security, system integration, and support services. Your objective is to build trust and confidence in our services.

Finally, guide them through the initial setup process. Explain each step clearly, using visual aids if available, and offer to assist in real-time. Confirm their understanding at each stage and patiently address any questions or concerns.

Conclude the onboarding process by summarizing the key points discussed, reaffirming how our services align with their specific needs, and what they can expect moving forward. Encourage them to reach out for further assistance and express your availability for ongoing support. Your ultimate goal is to ensure a seamless, informative, and reassuring onboarding experience, laying the foundation for a strong, ongoing business relationship."

##################

"""


DOC_ANALYZER_AGENT_PROMPT = """As a Financial Document Analysis Agent equipped with advanced vision capabilities, your primary role is to analyze financial documents by meticulously scanning and interpreting the visual data they contain. Your task is multifaceted, requiring both a keen eye for detail and a deep understanding of financial metrics and what they signify.

When presented with a financial document, such as a balance sheet, income statement, or cash flow statement, begin by identifying the layout and structure of the document. Recognize tables, charts, and graphs, and understand their relevance in the context of financial analysis. Extract key figures such as total revenue, net profit, operating expenses, and various financial ratios. Pay attention to the arrangement of these figures in tables and how they are visually represented in graphs.

Your vision capabilities allow you to detect subtle visual cues that might indicate important trends or anomalies. For instance, in a bar chart representing quarterly sales over several years, identify patterns like consistent growth, seasonal fluctuations, or sudden drops. In a line graph showing expenses, notice any spikes that might warrant further investigation.

Apart from numerical data, also focus on the textual components within the documents. Extract and comprehend written explanations or notes that accompany financial figures, as they often provide crucial context. For example, a note accompanying an expense report might explain a one-time expenditure that significantly impacted the company's financials for that period.

Go beyond mere data extraction and engage in a level of interpretation that synthesizes the visual and textual information into a coherent analysis. For instance, if the profit margins are shrinking despite increasing revenues, hypothesize potential reasons such as rising costs or changes in the market conditions.

As you process each document, maintain a focus on accuracy and reliability. Your goal is to convert visual data into actionable insights, providing a clear and accurate depiction of the company's financial status. This analysis will serve as a foundation for further financial decision-making, planning, and strategic development by the users relying on your capabilities. Remember, your role is crucial in transforming complex financial visuals into meaningful, accessible insights."

"""

SUMMARY_GENERATOR_AGENT_PROMPT = """
Summarizer:

"As the Financial Summary Generation Agent, your task is to synthesize the complex data extracted by the vision model into clear, concise, and insightful summaries. Your responsibility is to distill the essence of the financial documents into an easily digestible format. Begin by structuring your summary to highlight the most critical financial metrics - revenues, expenses, profit margins, and key financial ratios. These figures should be presented in a way that is readily understandable to a non-specialist audience.

Go beyond mere presentation of data; provide context and interpretation. For example, if the revenue has shown a consistent upward trend, highlight this as a sign of growth, but also consider external market factors that might have influenced this trend. Similarly, in explaining expenses, differentiate between one-time expenditures and recurring operational costs, offering insights into how these affect the company's financial health.

Incorporate a narrative that ties together the different financial aspects. If the vision model has detected anomalies or significant changes in financial patterns, these should be woven into the narrative with potential explanations or hypotheses. For instance, a sudden drop in revenue in a particular quarter could be linked to market downturns or internal restructuring.

Your summary should also touch upon forward-looking aspects. Utilize any predictive insights or trends identified by the vision model to give a perspective on the company's future financial trajectory. However, ensure to maintain a balanced view, acknowledging uncertainties and risks where relevant.

Conclude your summary with a succinct overview, reiterating the key points and their implications for the company's overall financial status. Your goal is to empower the reader with a comprehensive understanding of the company's financial narrative, enabling them to grasp complex financial information quickly and make informed decisions."

##################

"""

FRAUD_DETECTION_AGENT_PROMPT = """
Fraud Detection:

"As the Fraud Detection Agent, your mission is to meticulously scrutinize financial documents for any signs of fraudulent activities. Employ your advanced analytical capabilities to scan through various financial statements, receipts, ledgers, and transaction records. Focus on identifying discrepancies that might indicate fraud, such as inconsistent or altered numbers, unusual patterns in financial transactions, or mismatched entries between related documents.

Your approach should be both systematic and detail-oriented. Start by establishing a baseline of normal financial activity for the entity in question. Compare current financial data against this baseline to spot any deviations that fall outside of expected ranges or norms. Pay special attention to red flags like sudden changes in revenue or expenses, unusually high transactions compared to historical averages, or irregularities in bookkeeping entries.

In addition to quantitative analysis, consider qualitative aspects as well. Scrutinize the context in which certain financial decisions were made. Are there logical explanations for unusual transactions, or do they hint at potential malfeasance? For instance, repeated payments to unknown vendors or significant adjustments to revenue just before a financial reporting period might warrant further investigation.

Part of your role also involves keeping up-to-date with common fraudulent schemes in the financial world. Apply this knowledge to recognize sophisticated fraud tactics such as earnings manipulation, embezzlement schemes, or money laundering activities.

Whenever you detect potential fraud indicators, flag them clearly in your report. Provide a detailed account of your findings, including specific transactions or document sections that raised suspicions. Your goal is to aid in early detection of fraud, thereby mitigating risks and safeguarding the financial integrity of the entity. Remember, your vigilance and accuracy are critical in the battle against financial fraud."

##################

"""

DECISION_MAKING_PROMPT = """
Actionable Decision-Making:

"As the Decision-Making Support Agent, your role is to assist users in making informed financial decisions based on the analysis provided by the Financial Document Analysis and Summary Generation Agents. You are to provide actionable advice and recommendations, grounded in the data but also considering broader business strategies and market conditions.

Begin by reviewing the financial summaries and analysis reports, understanding the key metrics and trends they highlight. Cross-reference this data with industry benchmarks, economic trends, and best practices to provide well-rounded advice. For instance, if the analysis indicates a strong cash flow position, you might recommend strategic investments or suggest areas for expansion.

Address potential risks and opportunities. If the analysis reveals certain vulnerabilities, like over-reliance on a single revenue stream, advise on diversification strategies or risk mitigation tactics. Conversely, if there are untapped opportunities, such as emerging markets or technological innovations, highlight these as potential growth areas.

Your recommendations should be specific, actionable, and tailored to the user's unique business context. Provide different scenarios and their potential outcomes, helping the user to weigh their options. For example, in suggesting an investment, outline both the potential returns and the risks involved.

Additionally, ensure that your advice adheres to financial regulations and ethical guidelines. Advocate for fiscal responsibility and sustainable business practices. Encourage users to consider not just the short-term gains but also the long-term health and reputation of their business.

Ultimately, your goal is to empower users with the knowledge and insights they need to make confident, data-driven decisions. Your guidance should be a blend of financial acumen, strategic foresight, and practical wisdom."

"""
@ -0,0 +1,185 @@
|
||||
# Agent process automation
|
||||
system_prompt_1 = """You are a RPA(Robotic Process Automation) agent, you can write and test a RPA-Python-Code to connect different APPs together to reach a specific user query.
|
||||
|
||||
RPA-Python-Code:
|
||||
1. Each actions and triggers of APPs are defined as Action/Trigger-Functions, once you provide the specific_params for a function, then we will implement and test it **with some features that can influence outside-world and is transparent to you**.
|
||||
2. A RPA process is implemented as a workflow-function. the mainWorkflow function is activated when the trigger's conditions are reached.
|
||||
3. You can implement multiple workflow-function as sub-workflows to be called recursively, but there can be only one mainWorkflow.
|
||||
4. We will automatically test the workflows and actions with the Pinned-Data afer you change the specific_params.
|
||||
|
||||
Action/Trigger-Function: All the functions have the same following parameters:
|
||||
1.integration_name: where this function is from. A integration represent a list of actions and triggers from a APP.
|
||||
2.resource_name: This is the second category of a integration.
|
||||
3.operation_name: This is the third category of a integration. (integration->resouce->operation)
|
||||
4.specific_params: This is a json field, you will only see how to given this field after the above fields are selected.
|
||||
5.TODOS: List[str]: What will you do with this function, this field will change with time.
|
||||
6.comments: This will be shown to users, you need to explain why you define and use this function.
|
||||
|
||||
Workflow-Function:
|
||||
1. Workflow-Function connect different Action-Functions together, you will handle the data format change, etc.
|
||||
2. You must always have a mainWorkflow, whose inputs are a Trigger-function's output. If you define multiple triggers, The mainWorkflow will be activated when one of the trigger are activated, you must handle data type changes.
|
||||
3. You can define multiple subworkflow-function, Which whose inputs are provided by other workflows, You need to handle data-formats.
|
||||
|
||||
Testing-When-Implementing: We will **automatically** test all your actions, triggers and workflows with the pinned input data **at each time** once you change it.
|
||||
1. Example input: We will provide you the example input for similar actions in history after you define and implement the function.
|
||||
2. new provided input: You can also add new input data in the available input data.
|
||||
3. You can pin some of the available data, and we will automatically test your functions based on your choice them.
|
||||
4. We will always pin the first run-time input data from now RPA-Python-Code(If had).
|
||||
5.Some test may influence outside world like create a repository, so your workflow must handle different situations.
|
||||
|
||||
Data-Format: We ensure all the input/output data in transparent action functions have the format of List of Json: [{...}], length > 0
|
||||
1.All items in the list have the same json schema. The transparent will be activated for each item in the input-data. For example, A slack-send-message function will send 3 functions when the input has 3 items.
|
||||
2.All the json item must have a "json" field, in which are some custom fields.
|
||||
3.Some functions' json items have a additional "binary" field, which contains raw data of images, csv, etc.
|
||||
4.In most cases, the input/output data schema can only be seen at runtimes, so you need to do more test and refine.
|
||||
|
||||
Java-Script-Expression:
|
||||
1.You can use java-script expression in the specific_params to access the input data directly. Use it by a string startswith "=", and provide expression inside a "{{...}}" block.
|
||||
2. Use "{{$json["xxx"]}}" to obtain the "json" field in each item of the input data.
|
||||
3. You can use expression in "string" , "number", "boolean" and "json" type, such as:
|
||||
string: "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old
|
||||
boolean: "={{$json["age"] > 20}}"
|
||||
number: "={{$json["year"] + 10.5}}"
|
||||
json: "={ "new_age":{{$json["year"] + 5}} }"
|
||||
|
||||
For example, in slack-send-message. The input looks like:
|
||||
[
|
||||
{
|
||||
"json": {
|
||||
"name": "Alice",
|
||||
"age": 15,
|
||||
}
|
||||
},
|
||||
{
|
||||
"json": {
|
||||
"name": "Jack",
|
||||
"age": 20,
|
||||
}
|
||||
}
|
||||
]
|
||||
When you set the field "message text" as "=Hello {{$json["name"]}}, you are {{$json["age"]}} years old.", then the message will be send as:
|
||||
[
|
||||
"Hello Alice, you are 15 years old.",
|
||||
"Hello Jack, you are 20 years old.",
|
||||
]
|
||||
|
||||
Based on the above information, the full RPA-Python-Code looks like the following:
|
||||
```
from transparent_server import transparent_action, transparent_trigger

# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
# Available_datas: all the available Datas: data_1, data_2...
# Pinned_data_ID: all the input data you pinned and their execution results
# ID=1, output: xxx
# ID=3, output: xxx
# Runtime_input_data: the runtime input of this function (first time)
# Runtime_output_data: the corresponding output
def action_1(input_data: [{...}]):
    # comments: some comments to users. Always give/change this when defining and implementing
    # TODOS:
    # 1. I will provide the information at runtime
    # 2. I will test the node
    # 3. ...Always give/change this when defining and implementing
    specific_params = {
        "key_1": value_1,
        "key_2": [
            {
                "subkey_2": value_2,
            }
        ],
        "key_3": {
            "subkey_3": value_3,
        },
        # You will implement this after function-define
    }
    function = transparent_action(integration=xxx, resource=yyy, operation=zzz)
    output_data = function.run(input_data=input_data, params=params)
    return output_data

def action_2(input_data: [{...}]): ...
def action_3(input_data: [{...}]): ...
def action_4(input_data: [{...}]): ...

# Specific_params: After you give function_define, we will provide json schemas of specific_params here.
# Trigger functions have no input and share the same output format, so we will provide you the example_output once you change the code here.
def trigger_1():
    # comments: some comments to users. Always give/change this when defining and implementing
    # TODOS:
    # 1. I will provide the information at runtime
    # 2. I will test the node
    # 3. ...Always give/change this when defining and implementing
    specific_params = {
        "key_1": value_1,
        "key_2": [
            {
                "subkey_2": value_2,
            }
        ],
        "key_3": {
            "subkey_3": value_3,
        },
        # You will implement this after function-define
    }
    function = transparent_trigger(integration=xxx, resource=yyy, operation=zzz)
    output_data = function.run(input_data=input_data, params=params)
    return output_data

def trigger_2(input_data: [{...}]): ...
def trigger_3(input_data: [{...}]): ...

# A subworkflow inputs the same json-schema and can be called by another workflow.
def subworkflow_1(father_workflow_input: [{...}]): ...
def subworkflow_2(father_workflow_input: [{...}]): ...

# If you defined the trigger node, we will show you the mocked trigger input here.
# If you have implemented the workflow, we will automatically run the workflow for all the mocked trigger inputs and tell you the results.
def mainWorkflow(trigger_input: [{...}]):
    # comments: some comments to users. Always give/change this when defining and implementing
    # TODOS:
    # 1. I will provide the information at runtime
    # 2. I will test the node
    # 3. ...Always give/change this when defining and implementing

    # some complex logic here
    output_data = trigger_input

    return output_data
```
"""
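The expression syntax described in the prompt above can be sketched in plain Python. This is only an illustrative sketch; the `render` helper and its regex are assumptions, not part of the repository:

```python
import re

def render(template: str, item: dict) -> str:
    """Render an n8n-style expression string against one input item.

    Expression strings start with "=" and reference fields as {{$json["field"]}};
    anything else is treated as a plain literal.
    """
    if not template.startswith("="):
        return template  # plain literal, no expression
    body = template[1:]
    # Replace each {{$json["key"]}} with the item's "json" value for that key.
    return re.sub(
        r'\{\{\$json\["([^"]+)"\]\}\}',
        lambda m: str(item["json"][m.group(1)]),
        body,
    )

items = [
    {"json": {"name": "Alice", "age": 15}},
    {"json": {"name": "Jack", "age": 20}},
]
messages = [
    render('=Hello {{$json["name"]}}, you are {{$json["age"]}} years old.', it)
    for it in items
]
# messages == ["Hello Alice, you are 15 years old.", "Hello Jack, you are 20 years old."]
```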
system_prompt_2 = """You will define and implement functions progressively over many steps. At each step, you can do one of the following actions:
1. functions_define: Define a list of functions (Action and Trigger). You must provide the (integration, resource, operation) fields, which cannot be changed later.
2. function_implement: After function define, we will provide you the specific_param schema of the target function. You can provide (or override) the specific_param through this function. We will show your available test_data after you implement functions.
3. workflow_implement: You can directly re-write an implementation of the target workflow.
4. add_test_data: Besides the provided history data, you can also add your custom test data for a function.
5. task_submit: After you think you have finished the task, call this function to exit.

Remember:
1. Always provide thought, plans and criticism before giving an action.
2. Always provide/change TODOs and comments for all the functions when you implement them. This helps you to further refine and debug later.
3. We will test functions automatically; you only need to change the pinned data.

"""

system_prompt_3 = """The user query:
{{user_query}}

You have access to use the following actions and triggers:

{{flatten_tools}}
"""

history_prompt = """In the {{action_count}}'th step, you made the following action:
{{action}}
"""

user_prompt = """Now the code looks like this:
```
{{now_codes}}
```

{{refine_prompt}}

Give your next action together with thought, plans and criticism:
"""
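The `{{user_query}}`-style placeholders in the prompts above are presumably substituted before a prompt is sent to the model. A minimal sketch of that substitution; the `fill` helper is hypothetical, not part of the repository:

```python
def fill(template: str, values: dict) -> str:
    # Substitute every {{name}} placeholder with its value (hypothetical helper).
    out = template
    for key, value in values.items():
        out = out.replace("{{" + key + "}}", str(value))
    return out

prompt = fill(
    "The user query:\n{{user_query}}\n\n"
    "You have access to use the following actions and triggers:\n\n{{flatten_tools}}\n",
    {
        "user_query": "Send a daily Slack summary",
        "flatten_tools": "slack-send-message, gmail-read-inbox",
    },
)
```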
@@ -1,6 +1,5 @@
from swarms.structs.workflow import Workflow
from swarms.structs.task import Task
from swarms.structs.flow import Flow
from swarms.structs.sequential_workflow import SequentialWorkflow
from swarms.structs.autoscaler import AutoScaler

__all__ = ["Workflow", "Task", "Flow", "SequentialWorkflow"]
__all__ = ["Flow", "SequentialWorkflow", "AutoScaler"]
@@ -0,0 +1,79 @@
from swarms.models import OpenAIChat
from swarms.structs.flow import Flow

import concurrent.futures
from typing import Callable, List, Dict, Any, Sequence


class Task:
    def __init__(self, id: str, task: str, flows: Sequence[Flow], dependencies: List[str] = []):
        self.id = id
        self.task = task
        self.flows = flows
        self.dependencies = dependencies
        self.results = []

    def execute(self, parent_results: Dict[str, Any]):
        args = [parent_results[dep] for dep in self.dependencies]
        for flow in self.flows:
            result = flow.run(self.task, *args)
            self.results.append(result)
            args = [result]  # The output of one flow becomes the input to the next


class Workflow:
    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        self.executor = concurrent.futures.ThreadPoolExecutor()

    def add_task(self, task: Task):
        self.tasks[task.id] = task

    def run(self):
        completed_tasks = set()
        while len(completed_tasks) < len(self.tasks):
            futures = []
            for task in self.tasks.values():
                if task.id not in completed_tasks and all(
                    dep in completed_tasks for dep in task.dependencies
                ):
                    future = self.executor.submit(
                        task.execute,
                        {dep: self.tasks[dep].results for dep in task.dependencies},
                    )
                    futures.append((future, task.id))

            for future, task_id in futures:
                future.result()  # Wait for task completion
                completed_tasks.add(task_id)

    def get_results(self):
        return {task_id: task.results for task_id, task in self.tasks.items()}


# create flows
llm = OpenAIChat(openai_api_key="sk-")

flow1 = Flow(llm, max_loops=1)
flow2 = Flow(llm, max_loops=1)
flow3 = Flow(llm, max_loops=1)
flow4 = Flow(llm, max_loops=1)


# Create tasks with their respective Flows and task strings
task1 = Task("task1", "Generate a summary on Quantum field theory", [flow1])
task2 = Task("task2", "Elaborate on the summary of topic X", [flow2, flow3], dependencies=["task1"])
task3 = Task("task3", "Generate conclusions for topic X", [flow4], dependencies=["task1"])

# Create a workflow and add tasks
workflow = Workflow()
workflow.add_task(task1)
workflow.add_task(task2)
workflow.add_task(task3)

# Run the workflow
workflow.run()

# Get results
results = workflow.get_results()
print(results)
@@ -1,105 +0,0 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from graphlib import TopologicalSorter
from typing import Dict, List


class Task:
    """
    Task is a unit of work that can be executed by an agent
    """

    def __init__(
        self, id: str, parents: List["Task"] = None, children: List["Task"] = None
    ):
        self.id = id
        self.parents = parents
        self.children = children

    def can_execute(self):
        """
        can_execute returns True if the task can be executed
        """
        raise NotImplementedError

    def execute(self):
        """
        Execute the task
        """
        raise NotImplementedError


class NonLinearWorkflow:
    """
    NonLinearWorkflow constructs a non-sequential DAG of tasks to be executed by agents

    Architecture:
    NonLinearWorkflow = Task + Agent + Executor

    ASCII Diagram:
    +-------------------+
    | NonLinearWorkflow |
    +-------------------+
    |                   |
    |                   |
    +-------------------+
    """

    def __init__(self, agents, iters_per_task):
        """A workflow is a collection of tasks that can be executed in parallel or sequentially."""
        super().__init__()
        self.executor = ThreadPoolExecutor()
        self.agents = agents
        self.tasks = []

    def add(self, task: Task):
        """Add a task to the workflow"""
        assert isinstance(task, Task), "Input must be an instance of Task"
        self.tasks.append(task)
        return task

    def run(self):
        """Run the workflow"""
        ordered_tasks = self.ordered_tasks
        exit_loop = False

        while not self.is_finished() and not exit_loop:
            futures_list = {}

            for task in ordered_tasks:
                if task.can_execute:
                    future = self.executor.submit(self.agents.run, task.task_string)
                    futures_list[future] = task

            for future in as_completed(futures_list):
                if isinstance(future.result(), Exception):
                    exit_loop = True
                    break
        return self.output_tasks()

    def output_tasks(self) -> List[Task]:
        """Output tasks from the workflow"""
        return [task for task in self.tasks if not task.children]

    def to_graph(self) -> Dict[str, set[str]]:
        """Convert the workflow to a graph"""
        graph = {
            task.id: set(child.id for child in task.children) for task in self.tasks
        }
        return graph

    def order_tasks(self) -> List[Task]:
        """Order the tasks USING TOPOLOGICAL SORTING"""
        task_order = TopologicalSorter(self.to_graph()).static_order()
        return [self.find_task(task_id) for task_id in task_order]
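`order_tasks` above relies on `graphlib.TopologicalSorter` from the standard library. A self-contained sketch of its semantics on a graph shaped like `to_graph`'s output (node -> set of children): note that `TopologicalSorter` interprets the mapped set as *predecessors*, so with this shape the children are yielded before their parents:

```python
from graphlib import TopologicalSorter

# Same shape as to_graph(): task id -> set of child ids.
# TopologicalSorter treats the set values as predecessors, so nodes whose
# set is empty come out first.
graph = {"a": {"b", "c"}, "b": {"d"}, "c": {"d"}, "d": set()}

order = list(TopologicalSorter(graph).static_order())
# "d" (no listed predecessors) comes first, "a" last.
```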
@@ -1,174 +0,0 @@
from __future__ import annotations

import json
import pprint
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Field, StrictStr
from swarms.artifacts.main import Artifact
from swarms.artifacts.error_artifact import ErrorArtifact


class BaseTask(ABC):
    class State(Enum):
        PENDING = 1
        EXECUTING = 2
        FINISHED = 3

    def __init__(self):
        self.id: str = uuid.uuid4().hex
        self.state: BaseTask.State = self.State.PENDING
        self.parent_ids: List[str] = []
        self.child_ids: List[str] = []
        self.output: Optional[Union[Artifact, ErrorArtifact]] = None
        self.structure = None

    @property
    @abstractmethod
    def input(self) -> Any:
        pass

    @property
    def parents(self) -> List[BaseTask]:
        return [self.structure.find_task(parent_id) for parent_id in self.parent_ids]

    @property
    def children(self) -> List[BaseTask]:
        return [self.structure.find_task(child_id) for child_id in self.child_ids]

    def __rshift__(self, child: BaseTask) -> BaseTask:
        return self.add_child(child)

    def __lshift__(self, child: BaseTask) -> BaseTask:
        return self.add_parent(child)

    def preprocess(self, structure) -> BaseTask:
        self.structure = structure
        return self

    def add_child(self, child: BaseTask) -> BaseTask:
        if self.structure:
            child.structure = self.structure
        elif child.structure:
            self.structure = child.structure

        if child not in self.structure.tasks:
            self.structure.tasks.append(child)

        if self not in self.structure.tasks:
            self.structure.tasks.append(self)

        if child.id not in self.child_ids:
            self.child_ids.append(child.id)

        if self.id not in child.parent_ids:
            child.parent_ids.append(self.id)

        return child

    def add_parent(self, parent: BaseTask) -> BaseTask:
        if self.structure:
            parent.structure = self.structure
        elif parent.structure:
            self.structure = parent.structure

        if parent not in self.structure.tasks:
            self.structure.tasks.append(parent)

        if self not in self.structure.tasks:
            self.structure.tasks.append(self)

        if parent.id not in self.parent_ids:
            self.parent_ids.append(parent.id)

        if self.id not in parent.child_ids:
            parent.child_ids.append(self.id)

        return parent

    def is_pending(self) -> bool:
        return self.state == self.State.PENDING

    def is_finished(self) -> bool:
        return self.state == self.State.FINISHED

    def is_executing(self) -> bool:
        return self.state == self.State.EXECUTING

    def before_run(self) -> None:
        pass

    def after_run(self) -> None:
        pass

    def execute(self) -> Optional[Union[Artifact, ErrorArtifact]]:
        try:
            self.state = self.State.EXECUTING
            self.before_run()
            self.output = self.run()
            self.after_run()
        except Exception as e:
            self.output = ErrorArtifact(str(e))
        finally:
            self.state = self.State.FINISHED
            return self.output

    def can_execute(self) -> bool:
        return self.state == self.State.PENDING and all(
            parent.is_finished() for parent in self.parents
        )

    def reset(self) -> BaseTask:
        self.state = self.State.PENDING
        self.output = None
        return self

    @abstractmethod
    def run(self) -> Optional[Union[Artifact, ErrorArtifact]]:
        pass


class Task(BaseModel):
    input: Optional[StrictStr] = Field(None, description="Input prompt for the task")
    additional_input: Optional[Any] = Field(
        None, description="Input parameters for the task. Any value is allowed"
    )
    task_id: StrictStr = Field(..., description="ID of the task")

    class Config:
        allow_population_by_field_name = True
        validate_assignment = True

    def to_str(self) -> str:
        return pprint.pformat(self.dict(by_alias=True))

    def to_json(self) -> str:
        return json.dumps(self.dict(by_alias=True, exclude_none=True))

    @classmethod
    def from_json(cls, json_str: str) -> "Task":
        return cls.parse_raw(json_str)

    def to_dict(self) -> dict:
        _dict = self.dict(by_alias=True, exclude_none=True)
        if self.artifacts:
            _dict["artifacts"] = [
                artifact.dict(by_alias=True, exclude_none=True)
                for artifact in self.artifacts
            ]
        return _dict

    @classmethod
    def from_dict(cls, obj: dict) -> "Task":
        if obj is None:
            return None
        if not isinstance(obj, dict):
            raise ValueError("Input must be a dictionary.")
        if "artifacts" in obj:
            obj["artifacts"] = [
                Artifact.parse_obj(artifact) for artifact in obj["artifacts"]
            ]
        return cls.parse_obj(obj)
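The pydantic `Task` above round-trips through JSON via `to_json`/`from_json`, dropping unset fields. A stdlib-only stand-in showing the same exclude-none round-trip; `TaskRecord` is a hypothetical dataclass, not the real model:

```python
import json
from dataclasses import dataclass, asdict
from typing import Any, Optional

@dataclass
class TaskRecord:
    # Minimal stand-in for the pydantic Task above (hypothetical, stdlib-only).
    task_id: str
    input: Optional[str] = None
    additional_input: Optional[Any] = None

    def to_json(self) -> str:
        # Analogue of exclude_none=True: drop fields that are None.
        return json.dumps({k: v for k, v in asdict(self).items() if v is not None})

    @classmethod
    def from_json(cls, json_str: str) -> "TaskRecord":
        return cls(**json.loads(json_str))

t = TaskRecord(task_id="1", input="What's the weather in miami")
restored = TaskRecord.from_json(t.to_json())
```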
@@ -1,83 +0,0 @@
from __future__ import annotations

import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from swarms.structs.task import Task


class Workflow:
    """
    Workflows are ideal for prescriptive processes that need to be executed
    sequentially.
    They string together multiple tasks of varying types, and can use Short-Term Memory
    or pass specific arguments downstream.

    Usage
        llm = LLM()
        workflow = Workflow(llm)

        workflow.add("What's the weather in miami")
        workflow.add("Provide details for {{ parent_output }}")
        workflow.add("Summarize the above information: {{ parent_output }}")

        workflow.run()
    """

    def __init__(self, agent, parallel: bool = False):
        """__init__"""
        self.agent = agent
        self.tasks: List[Task] = []
        self.parallel = parallel

    def add(self, task: str) -> Task:
        """Add a task"""
        task = Task(task_id=uuid.uuid4().hex, input=task)

        if self.last_task():
            self.last_task().add_child(task)
        else:
            task.structure = self
        self.tasks.append(task)
        return task

    def first_task(self) -> Optional[Task]:
        """First task"""
        return self.tasks[0] if self.tasks else None

    def last_task(self) -> Optional[Task]:
        """Last task"""
        return self.tasks[-1] if self.tasks else None

    def run(self, task: str) -> Task:
        """Run tasks"""
        self.add(task)

        if self.parallel:
            with ThreadPoolExecutor() as executor:
                list(executor.map(self.__run_from_task, [self.first_task()]))
        else:
            self.__run_from_task(self.first_task())

        return self.last_task()

    def context(self, task: Task) -> Dict[str, Any]:
        """Context in tasks"""
        return {
            "parent_output": task.parents[0].output
            if task.parents and task.parents[0].output
            else None,
            "parent": task.parents[0] if task.parents else None,
            "child": task.children[0] if task.children else None,
        }

    def __run_from_task(self, task: Optional[Task]) -> None:
        """Run from task"""
        if task is None:
            return
        else:
            if isinstance(task.execute(), Exception):
                return
            else:
                self.__run_from_task(next(iter(task.children), None))
@@ -1,17 +1,14 @@
from swarms.swarms.dialogue_simulator import DialogueSimulator
from swarms.swarms.autoscaler import AutoScaler

# from swarms.swarms.orchestrate import Orchestrator
from swarms.structs.autoscaler import AutoScaler
from swarms.swarms.god_mode import GodMode
from swarms.swarms.simple_swarm import SimpleSwarm
from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker
from swarms.swarms.multi_agent_collab import MultiAgentCollaboration


__all__ = [
    "DialogueSimulator",
    "AutoScaler",
    # "Orchestrator",
    "GodMode",
    "SimpleSwarm",
    "MultiAgentDebate",
    "select_speaker",
    "MultiAgentCollaboration",
]
@@ -1,76 +0,0 @@
from swarms.structs.flow import Flow


# Define a selection function
def select_speaker(step: int, agents) -> int:
    # This function selects the speaker in a round-robin fashion
    return step % len(agents)


class MultiAgentDebate:
    """
    MultiAgentDebate

    Args:
        agents: Flow
        selection_func: callable
        max_iters: int

    Usage:
    >>> from swarms import MultiAgentDebate
    >>> from swarms.structs.flow import Flow
    >>> agents = Flow()
    >>> agents.append(lambda x: x)
    >>> agents.append(lambda x: x)
    >>> agents.append(lambda x: x)
    """

    def __init__(
        self,
        agents: Flow,
        selection_func: callable = select_speaker,
        max_iters: int = None,
    ):
        self.agents = agents
        self.selection_func = selection_func
        self.max_iters = max_iters

    def inject_agent(self, agent):
        """Injects an agent into the debate"""
        self.agents.append(agent)

    def run(
        self,
        task: str,
    ):
        """
        MultiAgentDebate

        Args:
            task: str

        Returns:
            results: list
        """
        results = []
        for i in range(self.max_iters or len(self.agents)):
            speaker_idx = self.selection_func(i, self.agents)
            speaker = self.agents[speaker_idx]
            response = speaker(task)
            results.append({"response": response})
        return results

    def update_task(self, task: str):
        """Update the task"""
        self.task = task

    def format_results(self, results):
        """Format the results"""
        formatted_results = "\n".join(
            [f"Agent responded: {result['response']}" for result in results]
        )

        return formatted_results
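The debate loop above can be sketched with plain callables standing in for agents; the lambdas here are hypothetical stand-ins, not the real swarms `Flow` API:

```python
def select_speaker(step: int, agents) -> int:
    # Round-robin: the speaker index cycles through the agents by step.
    return step % len(agents)

agents = [
    lambda task: f"agent0: {task}",
    lambda task: f"agent1: {task}",
]

results = []
for i in range(4):
    speaker = agents[select_speaker(i, agents)]
    results.append({"response": speaker("debate climate policy")})
# Speakers alternate: agent0, agent1, agent0, agent1.
```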
@@ -0,0 +1,154 @@
from enum import Enum, unique, auto
import abc
import hashlib
import re
from typing import List, Optional
import json
from dataclasses import dataclass, field


@unique
class LLMStatusCode(Enum):
    SUCCESS = 0
    ERROR = 1


@unique
class NodeType(Enum):
    action = auto()
    trigger = auto()


@unique
class WorkflowType(Enum):
    Main = auto()
    Sub = auto()


@unique
class ToolCallStatus(Enum):
    ToolCallSuccess = auto()
    ToolCallPartlySuccess = auto()
    NoSuchTool = auto()
    NoSuchFunction = auto()
    InputCannotParsed = auto()

    UndefinedParam = auto()
    ParamTypeError = auto()
    UnSupportedParam = auto()
    UnsupportedExpression = auto()
    ExpressionError = auto()
    RequiredParamUnprovided = auto()


@unique
class TestDataType(Enum):
    NoInput = auto()
    TriggerInput = auto()
    ActionInput = auto()
    SubWorkflowInput = auto()


@unique
class RunTimeStatus(Enum):
    FunctionExecuteSuccess = auto()
    TriggerAcivatedSuccess = auto()
    ErrorRaisedHere = auto()
    ErrorRaisedInner = auto()
    DidNotImplemented = auto()
    DidNotBeenCalled = auto()


@dataclass
class TestResult:
    """
    Responsible for handling the data structure of [{}]
    """

    data_type: TestDataType = TestDataType.ActionInput

    input_data: Optional[list] = field(default_factory=lambda: [])

    runtime_status: RunTimeStatus = RunTimeStatus.DidNotBeenCalled
    visit_times: int = 0

    error_message: str = ""
    output_data: Optional[list] = field(default_factory=lambda: [])

    def load_from_json(self):
        pass

    def to_json(self):
        pass

    def to_str(self):
        prompt = f"""
This function has been executed {self.visit_times} times. Last execution:
1. Status: {self.runtime_status.name}
2. Input:
{self.input_data}

3. Output:
{self.output_data}"""
        return prompt
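`TestResult` above uses `field(default_factory=...)` for its list fields rather than a literal default. A minimal sketch of why that matters; the `Result` class is hypothetical:

```python
from dataclasses import dataclass, field

@dataclass
class Result:
    # field(default_factory=list) gives every instance its own fresh list.
    # A bare `= []` default would be rejected by dataclasses (ValueError:
    # mutable default), and in plain classes it is silently shared.
    input_data: list = field(default_factory=list)
    visit_times: int = 0

a, b = Result(), Result()
a.input_data.append("x")
# b.input_data stays independent of a.input_data.
```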
@dataclass
class Action:
    content: str = ""
    thought: str = ""
    plan: List[str] = field(default_factory=lambda: [])
    criticism: str = ""
    tool_name: str = ""
    tool_input: dict = field(default_factory=lambda: {})

    tool_output_status: ToolCallStatus = ToolCallStatus.ToolCallSuccess
    tool_output: str = ""

    def to_json(self):
        try:
            tool_output = json.loads(self.tool_output)
        except json.JSONDecodeError:
            tool_output = self.tool_output
        return {
            "thought": self.thought,
            "plan": self.plan,
            "criticism": self.criticism,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output_status": self.tool_output_status.name,
            "tool_output": tool_output,
        }


@dataclass
class userQuery:
    task: str
    additional_information: List[str] = field(default_factory=lambda: [])
    refine_prompt: str = field(default_factory=lambda: "")

    def print_self(self):
        lines = [self.task]
        for info in self.additional_information:
            lines.append(f"- {info}")
        return "\n".join(lines)


class Singleton(abc.ABCMeta, type):
    """
    Singleton metaclass for ensuring only one instance of a class.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Call method for the singleton metaclass."""
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class AbstractSingleton(abc.ABC, metaclass=Singleton):
    """
    Abstract singleton class for ensuring only one instance of a class.
    """
@@ -1,23 +0,0 @@
from rich import print as rich_print
from rich.markdown import Markdown
from rich.rule import Rule


def display_markdown_message(message):
    """
    Display a markdown message. Works with multiline strings with lots of indentation.
    Will automatically make single-line > tags beautiful.
    """

    for line in message.split("\n"):
        line = line.strip()
        if line == "":
            print("")
        elif line == "---":
            rich_print(Rule(style="white"))
        else:
            rich_print(Markdown(line))

    if "\n" not in message and message.startswith(">"):
        # Aesthetic choice. For these tags, they need a space below them
        print("")
@@ -0,0 +1,501 @@
"""Logging modules"""
import logging
import os
import random
import re
import time
import json
from logging import LogRecord
from typing import Any

from colorama import Fore, Style
from swarms.utils.apa import Action, ToolCallStatus


# from autogpt.speech import say_text
class JsonFileHandler(logging.FileHandler):
    def __init__(self, filename, mode="a", encoding=None, delay=False):
        """
        Initializes a new instance of the class with the specified file name, mode, encoding, and delay settings.

        Parameters:
            filename (str): The name of the file to be opened.
            mode (str, optional): The mode in which the file is opened. Defaults to "a" (append).
            encoding (str, optional): The encoding used to read or write the file. Defaults to None.
            delay (bool, optional): If True, the file opening is delayed until the first IO operation. Defaults to False.

        Returns:
            None
        """
        super().__init__(filename, mode, encoding, delay)

    def emit(self, record):
        """
        Writes the formatted log record to a JSON file.

        Parameters:
            record (LogRecord): The log record to be emitted.

        Returns:
            None
        """
        json_data = json.loads(self.format(record))
        with open(self.baseFilename, "w", encoding="utf-8") as f:
            json.dump(json_data, f, ensure_ascii=False, indent=4)


class JsonFormatter(logging.Formatter):
    def format(self, record):
        """
        Format the given record and return the message.

        Args:
            record (object): The log record to be formatted.

        Returns:
            str: The formatted message from the record.
        """
        return record.msg


class Logger:
    """
    Logger that handles titles in different colors.
    Outputs logs to the console, activity.log, and errors.log.
    For the console handler: simulates typing.
    """

    def __init__(self):
        """
        Initializes the class and sets up the logging configuration.

        Args:
            None

        Returns:
            None
        """
        # create log directory if it doesn't exist
        this_files_dir_path = os.path.dirname(__file__)
        log_dir = os.path.join(this_files_dir_path, "../logs")
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        log_file = "activity.log"
        error_file = "error.log"

        console_formatter = AutoGptFormatter("%(title_color)s %(message)s")

        # Create a handler for the console which simulates typing
        self.typing_console_handler = TypingConsoleHandler()
        # self.typing_console_handler = ConsoleHandler()
        self.typing_console_handler.setLevel(logging.INFO)
        self.typing_console_handler.setFormatter(console_formatter)

        # Create a handler for the console without typing simulation
        self.console_handler = ConsoleHandler()
        self.console_handler.setLevel(logging.DEBUG)
        self.console_handler.setFormatter(console_formatter)

        # Info handler in activity.log
        self.file_handler = logging.FileHandler(
            os.path.join(log_dir, log_file), "a", "utf-8"
        )
        self.file_handler.setLevel(logging.DEBUG)
        info_formatter = AutoGptFormatter(
            "%(asctime)s %(levelname)s %(title)s %(message_no_color)s"
        )
        self.file_handler.setFormatter(info_formatter)

        # Error handler for error.log
        error_handler = logging.FileHandler(
            os.path.join(log_dir, error_file), "a", "utf-8"
        )
        error_handler.setLevel(logging.ERROR)
        error_formatter = AutoGptFormatter(
            "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s"
            " %(message_no_color)s"
        )
        error_handler.setFormatter(error_formatter)

        self.typing_logger = logging.getLogger("TYPER")
        self.typing_logger.addHandler(self.typing_console_handler)
        # self.typing_logger.addHandler(self.console_handler)
        self.typing_logger.addHandler(self.file_handler)
        self.typing_logger.addHandler(error_handler)
        self.typing_logger.setLevel(logging.DEBUG)

        self.logger = logging.getLogger("LOGGER")
        self.logger.addHandler(self.console_handler)
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(error_handler)
        self.logger.setLevel(logging.DEBUG)

        self.json_logger = logging.getLogger("JSON_LOGGER")
        self.json_logger.addHandler(self.file_handler)
        self.json_logger.addHandler(error_handler)
        self.json_logger.setLevel(logging.DEBUG)

        self.speak_mode = False
        self.chat_plugins = []
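The handler wiring above follows a standard pattern: one logger fans out to several handlers, each with its own level and formatter. A minimal stdlib-only sketch using in-memory streams instead of the files and custom formatters above:

```python
import io
import logging

info_buf, err_buf = io.StringIO(), io.StringIO()

# Handler that accepts INFO and above.
info_handler = logging.StreamHandler(info_buf)
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))

# Handler that accepts ERROR and above only, with a richer format.
error_handler = logging.StreamHandler(err_buf)
error_handler.setLevel(logging.ERROR)
error_handler.setFormatter(
    logging.Formatter("%(levelname)s %(funcName)s %(message)s")
)

logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)
logger.propagate = False  # keep records out of the root logger
logger.addHandler(info_handler)
logger.addHandler(error_handler)

logger.info("activity")  # reaches only info_buf
logger.error("boom")     # reaches both buffers
```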
    def typewriter_log(
        self, title="", title_color="", content="", speak_text=False, level=logging.INFO
    ):
        """
        Logs a message to the typewriter.

        Args:
            title (str, optional): The title of the log message. Defaults to "".
            title_color (str, optional): The color of the title. Defaults to "".
            content (str or list, optional): The content of the log message. Defaults to "".
            speak_text (bool, optional): Whether to speak the log message. Defaults to False.
            level (int, optional): The logging level of the message. Defaults to logging.INFO.

        Returns:
            None
        """
        for plugin in self.chat_plugins:
            plugin.report(f"{title}. {content}")

        if content:
            if isinstance(content, list):
                content = " ".join(content)
        else:
            content = ""

        self.typing_logger.log(
            level, content, extra={"title": title, "color": title_color}
        )

    def debug(
        self,
        message,
        title="",
        title_color="",
    ):
        """
        Logs a debug message.

        Args:
            message (str): The debug message to log.
            title (str, optional): The title of the log message. Defaults to "".
            title_color (str, optional): The color of the log message title. Defaults to "".

        Returns:
            None
        """
        self._log(title, title_color, message, logging.DEBUG)

    def info(
        self,
        message,
        title="",
        title_color="",
    ):
        """
        Logs an informational message.

        Args:
            message (str): The message to be logged.
            title (str, optional): The title of the log message. Defaults to "".
            title_color (str, optional): The color of the log title. Defaults to "".

        Returns:
            None
        """
        self._log(title, title_color, message, logging.INFO)

    def warn(
        self,
        message,
        title="",
        title_color="",
    ):
        """
        Logs a warning message.

        Args:
            message (str): The warning message.
            title (str, optional): The title of the warning message. Defaults to "".
            title_color (str, optional): The color of the title. Defaults to "".
        """
        self._log(title, title_color, message, logging.WARN)

    def error(self, title, message=""):
        """
        Logs an error message with the given title and optional message.

        Parameters:
            title (str): The title of the error message.
            message (str, optional): The optional additional message for the error. Defaults to an empty string.
        """
        self._log(title, Fore.RED, message, logging.ERROR)

    def _log(
        self,
        title: str = "",
        title_color: str = "",
        message: str = "",
        level=logging.INFO,
    ):
        """
        Logs a message with the given title and message at the specified log level.

        Parameters:
            title (str): The title of the log message. Defaults to an empty string.
            title_color (str): The color of the log message title. Defaults to an empty string.
            message (str): The log message. Defaults to an empty string.
            level (int): The log level. Defaults to logging.INFO.

        Returns:
|
||||
None
|
||||
"""
|
||||
if message:
|
||||
if isinstance(message, list):
|
||||
message = " ".join(message)
|
||||
self.logger.log(
|
||||
level, message, extra={"title": str(title), "color": str(title_color)}
|
||||
)
|
||||
|
||||
def set_level(self, level):
|
||||
"""
|
||||
Set the level of the logger and the typing_logger.
|
||||
|
||||
Args:
|
||||
level: The level to set the logger to.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.logger.setLevel(level)
|
||||
self.typing_logger.setLevel(level)
|
||||
|
||||
def double_check(self, additionalText=None):
|
||||
"""
|
||||
A function that performs a double check on the configuration.
|
||||
|
||||
Parameters:
|
||||
additionalText (str, optional): Additional text to be included in the double check message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not additionalText:
|
||||
additionalText = (
|
||||
"Please ensure you've setup and configured everything"
|
||||
" correctly. Read https://github.com/Torantulino/Auto-GPT#readme to "
|
||||
"double check. You can also create a github issue or join the discord"
|
||||
" and ask there!"
|
||||
)
|
||||
|
||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||
|
||||
def log_json(self, data: Any, file_name: str) -> None:
|
||||
"""
|
||||
Logs the given JSON data to a specified file.
|
||||
|
||||
Args:
|
||||
data (Any): The JSON data to be logged.
|
||||
file_name (str): The name of the file to log the data to.
|
||||
|
||||
Returns:
|
||||
None: This function does not return anything.
|
||||
"""
|
||||
# Define log directory
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
|
||||
# Create a handler for JSON files
|
||||
json_file_path = os.path.join(log_dir, file_name)
|
||||
json_data_handler = JsonFileHandler(json_file_path)
|
||||
json_data_handler.setFormatter(JsonFormatter())
|
||||
|
||||
# Log the JSON data using the custom file handler
|
||||
self.json_logger.addHandler(json_data_handler)
|
||||
self.json_logger.debug(data)
|
||||
self.json_logger.removeHandler(json_data_handler)
|
||||
|
||||
def get_log_directory(self):
|
||||
"""
|
||||
Returns the absolute path to the log directory.
|
||||
|
||||
Returns:
|
||||
str: The absolute path to the log directory.
|
||||
"""
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
return os.path.abspath(log_dir)
|
||||
|
||||
|
||||
"""
|
||||
Output stream to console using simulated typing
|
||||
"""
|
||||
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record):
|
||||
"""
|
||||
Emit a log record to the console with simulated typing effect.
|
||||
|
||||
Args:
|
||||
record (LogRecord): The log record to be emitted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs while emitting the log record.
|
||||
"""
|
||||
min_typing_speed = 0.05
|
||||
max_typing_speed = 0.10
|
||||
# min_typing_speed = 0.005
|
||||
# max_typing_speed = 0.010
|
||||
|
||||
msg = self.format(record)
|
||||
try:
|
||||
# replace enter & indent with other symbols
|
||||
transfer_enter = "<ENTER>"
|
||||
msg_transfered = str(msg).replace("\n", transfer_enter)
|
||||
transfer_space = "<4SPACE>"
|
||||
msg_transfered = str(msg_transfered).replace(" ", transfer_space)
|
||||
words = msg_transfered.split()
|
||||
words = [word.replace(transfer_enter, "\n") for word in words]
|
||||
words = [word.replace(transfer_space, " ") for word in words]
|
||||
|
||||
for i, word in enumerate(words):
|
||||
print(word, end="", flush=True)
|
||||
if i < len(words) - 1:
|
||||
print(" ", end="", flush=True)
|
||||
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
|
||||
time.sleep(typing_speed)
|
||||
# type faster after each word
|
||||
min_typing_speed = min_typing_speed * 0.95
|
||||
max_typing_speed = max_typing_speed * 0.95
|
||||
print()
|
||||
except Exception:
|
||||
self.handleError(record)
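The `emit` logic above protects newlines and 4-space indents behind placeholder tokens so `split()` cannot destroy them, then restores them word by word. A standalone sketch of that round-trip, independent of the logging machinery:

```python
# Sketch of the placeholder round-trip used by TypingConsoleHandler.emit:
# newlines and 4-space indents are swapped for tokens so str.split() keeps them.
msg = "line one\n    indented line two"

transfer_enter = "<ENTER>"
transfer_space = "<4SPACE>"

# Protect the whitespace, split into words, then restore it per word.
protected = msg.replace("\n", transfer_enter).replace("    ", transfer_space)
words = protected.split()
words = [w.replace(transfer_enter, "\n") for w in words]
words = [w.replace(transfer_space, "    ") for w in words]

restored = " ".join(words)
print(restored == msg)  # → True
```

Joining the restored words with single spaces reproduces the original message exactly, because the protected newline and indent travel inside a word rather than between words.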


class ConsoleHandler(logging.StreamHandler):
    def emit(self, record) -> None:
        """
        Emit the log record.

        Args:
            record (logging.LogRecord): The log record to emit.

        Returns:
            None: This function does not return anything.
        """
        msg = self.format(record)
        try:
            print(msg)
        except Exception:
            self.handleError(record)


class AutoGptFormatter(logging.Formatter):
    """
    Handles the custom placeholders 'title_color' and 'message_no_color'.
    To use this formatter, make sure to pass 'color' and 'title' as log extras.
    """

    def format(self, record: LogRecord) -> str:
        """
        Formats a log record into a string representation.

        Args:
            record (LogRecord): The log record to be formatted.

        Returns:
            str: The formatted log record as a string.
        """
        if hasattr(record, "color"):
            record.title_color = (
                getattr(record, "color")
                + getattr(record, "title", "")
                + " "
                + Style.RESET_ALL
            )
        else:
            record.title_color = getattr(record, "title", "")

        # Ensure 'title' exists even when it was not passed as an extra
        record.title = getattr(record, "title", "")

        if hasattr(record, "msg"):
            record.message_no_color = remove_color_codes(getattr(record, "msg"))
        else:
            record.message_no_color = ""
        return super().format(record)
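`AutoGptFormatter` reads values that callers pass through the `extra` dict of a logging call. A minimal, self-contained illustration of that mechanism — the `TitleFormatter` class here is a hypothetical stand-in, not the repo's formatter:

```python
import io
import logging

# Hypothetical minimal formatter that consumes a "title" value passed via the
# `extra` dict, the same mechanism AutoGptFormatter uses for "title"/"color".
class TitleFormatter(logging.Formatter):
    def format(self, record):
        # Default the attribute so records logged without the extra still format.
        record.title = getattr(record, "title", "")
        return super().format(record)

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(TitleFormatter("%(title)s %(message)s"))

log = logging.getLogger("formatter_demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

# The extra dict becomes attributes on the LogRecord.
log.info("hello", extra={"title": "GREETING"})
print(stream.getvalue().strip())  # → GREETING hello
```

Anything in `extra` is set as an attribute on the `LogRecord`, which is why the formatter can reference `%(title)s` in its format string.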


def remove_color_codes(s: str) -> str:
    """
    Removes color codes from a given string.

    Args:
        s (str): The string from which to remove color codes.

    Returns:
        str: The string with color codes removed.
    """
    ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
    return ansi_escape.sub("", s)
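A quick check of the ANSI-stripping behavior, applying the same escape-sequence pattern to a color-wrapped string:

```python
import re

# Same ANSI escape pattern as remove_color_codes, applied standalone.
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")

# "\x1b[31m" switches the terminal to red; "\x1b[0m" resets it.
colored = "\x1b[31mERROR\x1b[0m: something failed"
plain = ansi_escape.sub("", colored)
print(plain)  # → ERROR: something failed
```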


logger = Logger()


def print_action_base(action: Action):
    """
    Print the different properties of an Action object.

    Parameters:
        action (Action): The Action object to print.

    Returns:
        None
    """
    if action.content != "":
        logger.typewriter_log("Content:", Fore.YELLOW, f"{action.content}")
    logger.typewriter_log("Thought:", Fore.YELLOW, f"{action.thought}")
    if len(action.plan) > 0:
        logger.typewriter_log(
            "Plan:",
            Fore.YELLOW,
        )
        for line in action.plan:
            line = line.lstrip("- ")
            logger.typewriter_log("- ", Fore.GREEN, line.strip())
    logger.typewriter_log("Criticism:", Fore.YELLOW, f"{action.criticism}")


def print_action_tool(action: Action):
    """
    Prints the details of an action tool.

    Args:
        action (Action): The action object containing the tool details.

    Returns:
        None
    """
    logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}")
    logger.typewriter_log("Tool Input:", Fore.BLUE, f"{action.tool_input}")

    output = action.tool_output if action.tool_output != "" else "None"
    logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}")

    color = Fore.RED
    if action.tool_output_status == ToolCallStatus.ToolCallSuccess:
        color = Fore.GREEN
    elif action.tool_output_status == ToolCallStatus.InputCannotParsed:
        color = Fore.YELLOW

    logger.typewriter_log(
        "Tool Call Status:",
        Fore.BLUE,
        f"{color}{action.tool_output_status.name}{Style.RESET_ALL}",
    )
@ -0,0 +1,44 @@
import sys

try:
    import PyPDF2
except ImportError:
    print("PyPDF2 not installed. Please install it using: pip install PyPDF2")
    sys.exit(1)


def pdf_to_text(pdf_path):
    """
    Converts a PDF file to a string of text.

    Args:
        pdf_path (str): The path to the PDF file to be converted.

    Returns:
        str: The text extracted from the PDF.

    Raises:
        FileNotFoundError: If the PDF file is not found at the specified path.
        Exception: If there is an error in reading the PDF file.
    """
    try:
        # Open the PDF file
        with open(pdf_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""

            # Iterate through each page and extract text
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"

            return text
    except FileNotFoundError:
        raise FileNotFoundError(f"The file at {pdf_path} was not found.")
    except Exception as e:
        raise Exception(f"An error occurred while reading the PDF file: {e}")


# Example usage
# text = pdf_to_text("test.pdf")
# print(text)
@ -1,29 +0,0 @@
import os
import shutil
from pathlib import Path

# from env import DotEnv

from swarms.utils.main import AbstractUploader


class StaticUploader(AbstractUploader):
    def __init__(self, server: str, path: Path, endpoint: str):
        self.server = server
        self.path = path
        self.endpoint = endpoint

    @staticmethod
    def from_settings(path: Path, endpoint: str) -> "StaticUploader":
        return StaticUploader(os.environ["SERVER"], path, endpoint)

    def get_url(self, uploaded_path: str) -> str:
        return f"{self.server}/{uploaded_path}"

    def upload(self, filepath: str):
        relative_path = Path("generated") / filepath.split("/")[-1]
        file_path = self.path / relative_path
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        shutil.copy(filepath, file_path)
        endpoint_path = self.endpoint / relative_path
        return f"{self.server}/{endpoint_path}"
@ -0,0 +1,80 @@
import pytest
from unittest.mock import patch, mock_open
from swarms.models.eleven_labs import ElevenLabsText2SpeechTool, ElevenLabsModel
import os
from dotenv import load_dotenv

load_dotenv()


# Define some test data
SAMPLE_TEXT = "Hello, this is a test."
API_KEY = os.environ.get("ELEVEN_API_KEY")
EXPECTED_SPEECH_FILE = "expected_speech.wav"


@pytest.fixture
def eleven_labs_tool():
    return ElevenLabsText2SpeechTool()


# Basic functionality tests
def test_run_text_to_speech(eleven_labs_tool):
    speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
    assert isinstance(speech_file, str)
    assert speech_file.endswith(".wav")


def test_play_speech(eleven_labs_tool):
    with patch("builtins.open", mock_open(read_data="fake_audio_data")):
        eleven_labs_tool.play(EXPECTED_SPEECH_FILE)


def test_stream_speech(eleven_labs_tool):
    with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
        eleven_labs_tool.stream_speech(SAMPLE_TEXT)
        mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False)


# Testing fixture and environment variables
def test_api_key_validation(eleven_labs_tool):
    with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY):
        values = {"eleven_api_key": None}
        validated_values = eleven_labs_tool.validate_environment(values)
        assert "eleven_api_key" in validated_values


# Mocking the external library
def test_run_text_to_speech_with_mock(eleven_labs_tool):
    with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch(
        "your_module._import_elevenlabs"
    ) as mock_elevenlabs:
        mock_elevenlabs_instance = mock_elevenlabs.return_value
        mock_elevenlabs_instance.generate.return_value = b"fake_audio_data"
        eleven_labs_tool.run(SAMPLE_TEXT)
        assert mock_file.call_args[1]["suffix"] == ".wav"
        assert mock_file.call_args[1]["delete"] is False
        assert mock_file().write.call_args[0][0] == b"fake_audio_data"


# Exception testing
def test_run_text_to_speech_error_handling(eleven_labs_tool):
    with patch("your_module._import_elevenlabs") as mock_elevenlabs:
        mock_elevenlabs_instance = mock_elevenlabs.return_value
        mock_elevenlabs_instance.generate.side_effect = Exception("Test Exception")
        with pytest.raises(
            RuntimeError,
            match="Error while running ElevenLabsText2SpeechTool: Test Exception",
        ):
            eleven_labs_tool.run(SAMPLE_TEXT)


# Parameterized testing
@pytest.mark.parametrize(
    "model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL]
)
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
    eleven_labs_tool.model = model
    speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
    assert isinstance(speech_file, str)
    assert speech_file.endswith(".wav")
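Several of the tests above patch `builtins.open` with `mock_open` so no real file I/O happens. A minimal standalone sketch of that pattern — the `read_config` helper here is hypothetical, not part of the swarms API:

```python
from unittest.mock import patch, mock_open

# Hypothetical code under test: it opens a path and reads its contents.
def read_config(path):
    with open(path) as f:
        return f.read()

# mock_open stands in for the real file object, so the code "reads"
# canned data without touching the filesystem.
with patch("builtins.open", mock_open(read_data="fake_audio_data")):
    data = read_config("expected_speech.wav")

print(data)  # → fake_audio_data
```

Because `mock_open` supports the context-manager protocol and `read()`, code written against ordinary file handles works unchanged under the patch.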
@ -0,0 +1,26 @@
from swarms.models import __all__

EXPECTED_ALL = [
    "Anthropic",
    "Petals",
    "Mistral",
    "OpenAI",
    "AzureOpenAI",
    "OpenAIChat",
    "Zephyr",
    "Idefics",
    # "Kosmos",
    "Vilt",
    "Nougat",
    "LayoutLMDocumentQA",
    "BioGPT",
    "HuggingfaceLLM",
    "MPT7B",
    "WizardLLMStoryTeller",
    # "GPT4Vision",
    # "Dalle3",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
@ -1,76 +1,168 @@
from unittest.mock import patch
import json
import os
import pytest
from unittest.mock import Mock
from swarms.structs import Flow
from swarms.models import OpenAIChat
from swarms.swarms.multi_agent_collab import (
    MultiAgentCollaboration,
    Worker,
    select_next_speaker,
    select_next_speaker_director,
    select_speaker_round_table,
)

# Sample agents for testing
agent1 = Flow(llm=OpenAIChat(), max_loops=2)
agent2 = Flow(llm=OpenAIChat(), max_loops=2)
agents = [agent1, agent2]


def test_multiagentcollaboration_initialization():
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    assert isinstance(multiagentcollaboration, MultiAgentCollaboration)
    assert len(multiagentcollaboration.agents) == 5
    assert multiagentcollaboration._step == 0


@patch("swarms.workers.Worker.reset")
def test_multiagentcollaboration_reset(mock_reset):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    multiagentcollaboration.reset()
    assert mock_reset.call_count == 5


@patch("swarms.workers.Worker.run")
def test_multiagentcollaboration_inject(mock_run):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    multiagentcollaboration.inject("Agent 1", "Hello, world!")
    assert multiagentcollaboration._step == 1
    assert mock_run.call_count == 5


@patch("swarms.workers.Worker.send")
@patch("swarms.workers.Worker.receive")
def test_multiagentcollaboration_step(mock_receive, mock_send):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    multiagentcollaboration.step()
    assert multiagentcollaboration._step == 1
    assert mock_send.call_count == 5
    assert mock_receive.call_count == 25


@patch("swarms.workers.Worker.bid")
def test_multiagentcollaboration_ask_for_bid(mock_bid):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    result = multiagentcollaboration.ask_for_bid(Worker)
    assert isinstance(result, int)


@patch("swarms.workers.Worker.bid")
def test_multiagentcollaboration_select_next_speaker(mock_bid):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    result = multiagentcollaboration.select_next_speaker(1, [Worker] * 5)
    assert isinstance(result, int)


@patch("swarms.workers.Worker.send")
@patch("swarms.workers.Worker.receive")
def test_multiagentcollaboration_run(mock_receive, mock_send):
    multiagentcollaboration = MultiAgentCollaboration(
        agents=[Worker] * 5, selection_function=select_next_speaker
    )
    multiagentcollaboration.run(max_iters=5)
    assert multiagentcollaboration._step == 6
    assert mock_send.call_count == 30
    assert mock_receive.call_count == 150


@pytest.fixture
def collaboration():
    return MultiAgentCollaboration(agents)


def test_collaboration_initialization(collaboration):
    assert len(collaboration.agents) == 2
    assert callable(collaboration.select_next_speaker)
    assert collaboration.max_iters == 10
    assert collaboration.results == []
    assert collaboration.logging is True


def test_reset(collaboration):
    collaboration.reset()
    for agent in collaboration.agents:
        assert agent.step == 0


def test_inject(collaboration):
    collaboration.inject("TestName", "TestMessage")
    for agent in collaboration.agents:
        assert "TestName" in agent.history[-1]
        assert "TestMessage" in agent.history[-1]


def test_inject_agent(collaboration):
    agent3 = Flow(llm=OpenAIChat(), max_loops=2)
    collaboration.inject_agent(agent3)
    assert len(collaboration.agents) == 3
    assert agent3 in collaboration.agents


def test_step(collaboration):
    collaboration.step()
    for agent in collaboration.agents:
        assert agent.step == 1


def test_ask_for_bid(collaboration):
    agent = Mock()
    agent.bid.return_value = "<5>"
    bid = collaboration.ask_for_bid(agent)
    assert bid == 5


def test_select_next_speaker(collaboration):
    collaboration.select_next_speaker = Mock(return_value=0)
    idx = collaboration.select_next_speaker(1, collaboration.agents)
    assert idx == 0


def test_run(collaboration):
    collaboration.run()
    for agent in collaboration.agents:
        assert agent.step == collaboration.max_iters


def test_format_results(collaboration):
    collaboration.results = [{"agent": "Agent1", "response": "Response1"}]
    formatted_results = collaboration.format_results(collaboration.results)
    assert "Agent1 responded: Response1" in formatted_results


def test_save_and_load(collaboration):
    collaboration.save()
    loaded_state = collaboration.load()
    assert loaded_state["_step"] == collaboration._step
    assert loaded_state["results"] == collaboration.results


def test_performance(collaboration):
    performance_data = collaboration.performance()
    for agent in collaboration.agents:
        assert agent.name in performance_data
        assert "metrics" in performance_data[agent.name]


def test_set_interaction_rules(collaboration):
    rules = {"rule1": "action1", "rule2": "action2"}
    collaboration.set_interaction_rules(rules)
    assert hasattr(collaboration, "interaction_rules")
    assert collaboration.interaction_rules == rules
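The bidding tests rely on two `unittest.mock.Mock` features: canned return values and `side_effect` exceptions. A standalone sketch of both, independent of the swarms API:

```python
from unittest.mock import Mock

# A Mock agent whose bid() returns a bracketed score, as in test_ask_for_bid.
agent = Mock()
agent.bid.return_value = "<5>"
bid = int(agent.bid().strip("<>"))
print(bid)  # → 5

# side_effect makes the next call raise, as in test_exception_handling.
agent.bid.side_effect = ValueError("Invalid bid")
try:
    agent.bid()
    raised = False
except ValueError:
    raised = True
print(raised)  # → True
```

Setting `side_effect` overrides `return_value`, which is why the second call raises even though a return value was configured earlier.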


def test_repr(collaboration):
    repr_str = repr(collaboration)
    assert isinstance(repr_str, str)
    assert "MultiAgentCollaboration" in repr_str


def test_load(collaboration):
    state = {"step": 5, "results": [{"agent": "Agent1", "response": "Response1"}]}
    with open(collaboration.saved_file_path_name, "w") as file:
        json.dump(state, file)

    loaded_state = collaboration.load()
    assert loaded_state["_step"] == state["step"]
    assert loaded_state["results"] == state["results"]


def test_save(collaboration, tmp_path):
    collaboration.saved_file_path_name = tmp_path / "test_save.json"
    collaboration.save()

    with open(collaboration.saved_file_path_name, "r") as file:
        saved_data = json.load(file)

    assert saved_data["_step"] == collaboration._step
    assert saved_data["results"] == collaboration.results


# Add more tests here...


# Example of parameterized test for different selection functions
@pytest.mark.parametrize(
    "selection_function", [select_next_speaker_director, select_speaker_round_table]
)
def test_selection_functions(collaboration, selection_function):
    collaboration.select_next_speaker = selection_function
    assert callable(collaboration.select_next_speaker)


# Add more parameterized tests for different scenarios...


# Example of exception testing
def test_exception_handling(collaboration):
    agent = Mock()
    agent.bid.side_effect = ValueError("Invalid bid")
    with pytest.raises(ValueError):
        collaboration.ask_for_bid(agent)


# Add more exception testing...


# Example of environment variable testing (if applicable)
@pytest.mark.parametrize("env_var", ["ENV_VAR_1", "ENV_VAR_2"])
def test_environment_variables(collaboration, monkeypatch, env_var):
    monkeypatch.setenv(env_var, "test_value")
    assert os.getenv(env_var) == "test_value"