diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 21129735..be346103 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -22,4 +22,4 @@ jobs: - run: ruff format . - run: ruff check --fix . - - uses: autofix-ci/action@dd55f44df8f7cdb7a6bf74c78677eb8acd40cd0a + - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c diff --git a/.github/workflows/bearer.yml b/.github/workflows/bearer.yml new file mode 100644 index 00000000..be0fb591 --- /dev/null +++ b/.github/workflows/bearer.yml @@ -0,0 +1,43 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. +# +# This workflow file requires a free account on Bearer.com to manage findings, notifications and more. +# See https://docs.bearer.com/guides/bearer-cloud/ +name: Bearer + +on: + push: + branches: ["master" ] + pull_request: + # The branches below must be a subset of the branches above + branches: ["master"] + schedule: + - cron: '24 22 * * 6' + +permissions: + contents: read # for actions/checkout to fetch code + security-events: write # for github/codeql-action/upload-sarif to upload SARIF results + actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status + +jobs: + bearer: + runs-on: ubuntu-latest + steps: + # Checkout project source + - uses: actions/checkout@v4 + # Scan code using Bearer CLI + - name: Run Report + id: report + uses: bearer/bearer-action@828eeb928ce2f4a7ca5ed57fb8b59508cb8c79bc + with: + api-key: ${{ secrets.BEARER_TOKEN }} + format: sarif + output: results.sarif + exit-code: 0 + # Upload SARIF file generated in previous step + - name: Upload SARIF file + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: results.sarif diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml new file mode 100644 index 00000000..9bbf3ba2 --- /dev/null +++ b/.github/workflows/dependency-review.yml @@ -0,0 +1,39 @@ +# Dependency Review Action +# +# This Action will scan dependency manifest files that change as part of a Pull Request, +# surfacing known-vulnerable versions of the packages declared or updated in the PR. +# Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable +# packages will be blocked from merging. +# +# Source repository: https://github.com/actions/dependency-review-action +# Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement +name: 'Dependency review' +on: + pull_request: + branches: [ "master" ] + +# If using a dependency submission action in this workflow this permission will need to be set to: +# +# permissions: +# contents: write +# +# https://docs.github.com/en/enterprise-cloud@latest/code-security/supply-chain-security/understanding-your-software-supply-chain/using-the-dependency-submission-api +permissions: + contents: read + # Write permissions for pull-requests are required for using the `comment-summary-in-pr` option, comment out if you aren't using this option + pull-requests: write + +jobs: + dependency-review: + runs-on: ubuntu-latest + steps: + - name: 'Checkout repository' + uses: actions/checkout@v4 + - name: 'Dependency Review' + uses: actions/dependency-review-action@v4 + # Commonly enabled options, see https://github.com/actions/dependency-review-action#configuration-options for all available options. + with: + comment-summary-in-pr: always + # fail-on-severity: moderate + # deny-licenses: GPL-1.0-or-later, LGPL-2.0-or-later + # retry-on-snapshot-warnings: true diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml new file mode 100644 index 00000000..793d8e0e --- /dev/null +++ b/.github/workflows/docker-image.yml @@ -0,0 +1,18 @@ +name: Docker Image CI + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build the Docker image + run: docker build . --file Dockerfile --tag my-image-name:$(date +%s) diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml new file mode 100644 index 00000000..2e4713d3 --- /dev/null +++ b/.github/workflows/pyre.yml @@ -0,0 +1,46 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# This workflow integrates Pyre with GitHub's +# Code Scanning feature. +# +# Pyre is a performant type checker for Python compliant with +# PEP 484. Pyre can analyze codebases with millions of lines +# of code incrementally – providing instantaneous feedback +# to developers as they write code. +# +# See https://pyre-check.org + +name: Pyre + +on: + workflow_dispatch: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +permissions: + contents: read + +jobs: + pyre: + permissions: + actions: read + contents: read + security-events: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - name: Run Pyre + uses: facebook/pyre-action@60697a7858f7cc8470d8cc494a3cf2ad6b06560d + with: + # To customize these inputs: + # See https://github.com/facebook/pyre-action#inputs + repo-directory: './' + requirements-path: 'requirements.txt' diff --git a/.github/workflows/pysa.yml b/.github/workflows/pysa.yml new file mode 100644 index 00000000..6c301e80 --- /dev/null +++ b/.github/workflows/pysa.yml @@ -0,0 +1,50 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# This workflow integrates Python Static Analyzer (Pysa) with +# GitHub's Code Scanning feature. +# +# Python Static Analyzer (Pysa) is a security-focused static +# analysis tool that tracks flows of data from where they +# originate to where they terminate in a dangerous location. +# +# See https://pyre-check.org/docs/pysa-basics/ + +name: Pysa + +on: + workflow_dispatch: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + schedule: + - cron: '43 5 * * 3' + +permissions: + contents: read + +jobs: + pysa: + permissions: + actions: read + contents: read + security-events: write + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - name: Run Pysa + uses: facebook/pysa-action@f46a63777e59268613bd6e2ff4e29f144ca9e88b + with: + # To customize these inputs: + # See https://github.com/facebook/pysa-action#inputs + repo-directory: './' + requirements-path: 'requirements.txt' + infer-types: true + include-default-sapp-filters: true diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml new file mode 100644 index 00000000..f3586044 --- /dev/null +++ b/.github/workflows/python-package-conda.yml @@ -0,0 +1,34 @@ +name: Python Package using Conda + +on: [push] + +jobs: + build-linux: + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Install dependencies + run: | + conda env update --file environment.yml --name base + - name: Lint with flake8 + run: | + conda install flake8 + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + conda install pytest + pytest diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml new file mode 100644 index 00000000..1e78a687 --- /dev/null +++ b/.github/workflows/semgrep.yml @@ -0,0 +1,49 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# This workflow file requires a free account on Semgrep.dev to +# manage rules, file ignores, notifications, and more. +# +# See https://semgrep.dev/docs + +name: Semgrep + +on: + push: + branches: [ "master" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "master" ] + schedule: + - cron: '19 7 * * 3' + +permissions: + contents: read + +jobs: + semgrep: + permissions: + contents: read # for actions/checkout to fetch code + security-events: write # for github/codeql-action/upload-sarif to upload SARIF results + actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status + name: Scan + runs-on: ubuntu-latest + steps: + # Checkout project source + - uses: actions/checkout@v4 + + # Scan code using project's configuration on https://semgrep.dev/manage + - uses: returntocorp/semgrep-action@fcd5ab7459e8d91cb1777481980d1b18b4fc6735 + with: + publishToken: ${{ secrets.SEMGREP_APP_TOKEN }} + publishDeployment: ${{ secrets.SEMGREP_DEPLOYMENT_ID }} + generateSarif: "1" + + # Upload SARIF file generated in previous step + - name: Upload SARIF file + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: semgrep.sarif + if: always() diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml new file mode 100644 index 00000000..d9e6c82b --- /dev/null +++ b/.github/workflows/trivy.yml @@ -0,0 +1,48 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: trivy + +on: + push: + branches: [ "master" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "master" ] + schedule: + - cron: '31 0 * * 5' + +permissions: + contents: read + +jobs: + build: + permissions: + contents: read # for actions/checkout to fetch code + security-events: write # for github/codeql-action/upload-sarif to upload SARIF results + actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status + name: Build + runs-on: "ubuntu-20.04" + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Build an image from Dockerfile + run: | + docker build -t docker.io/my-organization/my-app:${{ github.sha }} . + + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@7b7aa264d83dc58691451798b4d117d53d21edfe + with: + image-ref: 'docker.io/my-organization/my-app:${{ github.sha }}' + format: 'template' + template: '@/contrib/sarif.tpl' + output: 'trivy-results.sarif' + severity: 'CRITICAL,HIGH' + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: 'trivy-results.sarif' diff --git a/.gitignore b/.gitignore index 89b0cdc7..9f6e25b6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,8 @@ audio/ video/ artifacts_three dataframe/ - +.ruff_cache +.pytest_cache static/generated runs Financial-Analysis-Agent_state.json diff --git a/README.md b/README.md index f98f43ea..071b1991 100644 --- a/README.md +++ b/README.md @@ -81,9 +81,10 @@ Refer to our documentation for production grade implementation details. ## Install 💻 +Install the following packages with copy and paste ```bash -$ pip3 install -U swarms +$ pip3 install -U swarms swarm-models swarms-memory ``` @@ -113,14 +114,36 @@ Here are some example scripts to get you started. For more comprehensive documen | Swarms Examples | A collection of simple examples to demonstrate Swarms capabilities. | Basic Usage | [https://github.com/The-Swarm-Corporation/swarms-examples?tab=readme-ov-file](https://github.com/The-Swarm-Corporation/swarms-examples?tab=readme-ov-file) | | Cookbook | A comprehensive guide with recipes for various use cases and scenarios. | Advanced Usage | [https://github.com/The-Swarm-Corporation/Cookbook](https://github.com/The-Swarm-Corporation/Cookbook) | + + + --- ## `Agent` Class The `Agent` class is a fundamental component of the Swarms framework, designed to execute tasks autonomously. It fuses llms, tools and long-term memory capabilities to create a full stack agent. The `Agent` class is highly customizable, allowing for fine-grained control over its behavior and interactions. + ### `run` Method -The `run` method is the primary entry point for executing tasks with an `Agent` instance. It accepts a task string as the main input task and processes it according to the agent's configuration. And, it can also accept an `img` parameter such as `img="image_filepath.png` to process images if you have a VLM +The `run` method is the primary entry point for executing tasks with an `Agent` instance. It accepts a task string as the main input task and processes it according to the agent's configuration. And, it can also accept an `img` parameter such as `img="image_filepath.png` to process images if you have a VLM attached such as `GPT4VisionAPI` + + + +## Simple Example + +```python +from swarms import Agent + +agent = Agent( + agent_name="Stock-Analysis-Agent", + model_name="gpt-4o-mini", + max_loops="auto", + interactive=True, + streaming_on=True, +) + +agent.run("What is the current market trend for tech stocks?") +``` ### Settings and Customization The `Agent` class offers a range of settings to tailor its behavior to specific needs. Some key settings include: @@ -146,28 +169,15 @@ The `Agent` class offers a range of settings to tailor its behavior to specific ```python import os from swarms import Agent -from swarm_models import OpenAIChat from swarms.prompts.finance_agent_sys_prompt import ( FINANCIAL_AGENT_SYS_PROMPT, ) -from dotenv import load_dotenv - -load_dotenv() - -# Get the OpenAI API key from the environment variable -api_key = os.getenv("OPENAI_API_KEY") - -# Create an instance of the OpenAIChat class -model = OpenAIChat( - openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 -) - # Initialize the agent agent = Agent( agent_name="Financial-Analysis-Agent", system_prompt=FINANCIAL_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o-mini", max_loops=1, autosave=True, dashboard=False, @@ -189,11 +199,10 @@ agent.run( ``` ----- + ### Integrating RAG with Swarms for Enhanced Long-Term Memory `Agent` equipped with quasi-infinite long term memory using RAG (Relational Agent Graph) for advanced document understanding, analysis, and retrieval capabilities. - - **Mermaid Diagram for RAG Integration** ```mermaid graph TD @@ -205,8 +214,11 @@ graph TD F --> G[Return Output] ``` -**Step 1: Initialize the ChromaDB Client** ```python +from swarms import Agent +from swarms.prompts.finance_agent_sys_prompt import ( + FINANCIAL_AGENT_SYS_PROMPT, +) import os from swarms_memory import ChromaDB @@ -217,29 +229,13 @@ chromadb = ChromaDB( output_dir="finance_agent_rag", # Directory for storing RAG data # docs_folder="artifacts", # Uncomment and specify the folder containing your documents ) -``` - -**Step 2: Define the Model** -```python -from swarm_models import Anthropic -from swarms.prompts.finance_agent_sys_prompt import ( - FINANCIAL_AGENT_SYS_PROMPT, -) - -# Define the Anthropic model for language processing -model = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY")) -``` - -**Step 3: Initialize the Agent with RAG** -```python -from swarms import Agent # Initialize the agent with RAG capabilities agent = Agent( agent_name="Financial-Analysis-Agent", system_prompt=FINANCIAL_AGENT_SYS_PROMPT, agent_description="Agent creates a comprehensive financial analysis", - llm=model, + model_name="gpt-4o-mini", max_loops="auto", # Auto-adjusts loops based on task complexity autosave=True, # Automatically saves agent state dashboard=False, # Disables dashboard for this example @@ -356,7 +352,6 @@ The following is an example of an agent that intakes a pydantic basemodel and ou ```python from pydantic import BaseModel, Field from swarms import Agent -from swarm_models import Anthropic # Initialize the schema for the person's information @@ -388,7 +383,7 @@ agent = Agent( ), # Set the tool schema to the JSON string -- this is the key difference tool_schema=tool_schema, - llm=Anthropic(), + model_name="gpt-4o", max_loops=3, autosave=True, dashboard=False, @@ -467,7 +462,7 @@ from pydantic import BaseModel, Field from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent -from swarms.utils.json_utils import base_model_to_json +from swarms.tools.json_utils import base_model_to_json # Load the pre-trained model and tokenizer model = AutoModelForCausalLM.from_pretrained( @@ -516,87 +511,8 @@ print(f"Generated data: {generated_data}") ``` -## Integrating External Agents -Integrating external agents from other agent frameworks is easy with swarms. - -Steps: - -1. Create a new class that inherits `Agent` -2. Create a `.run(task: str) -> str` method that runs the agent and returns the response. -3. The new Agent must return a string of the response. But you may add additional methods to save the output to JSON. - - -### Griptape Example - -For example, here's an example on how to create an agent from griptape. - -Here’s how you can create a custom **Griptape** agent that integrates with the **Swarms** framework by inheriting from the `Agent` class in **Swarms** and overriding the `run(task: str) -> str` method. - - -```python -from swarms import ( - Agent as SwarmsAgent, -) # Import the base Agent class from Swarms -from griptape.structures import Agent as GriptapeAgent -from griptape.tools import ( - WebScraperTool, - FileManagerTool, - PromptSummaryTool, -) - - -# Create a custom agent class that inherits from SwarmsAgent -class GriptapeSwarmsAgent(SwarmsAgent): - def __init__(self, *args, **kwargs): - # Initialize the Griptape agent with its tools - self.agent = GriptapeAgent( - input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", - tools=[ - WebScraperTool(off_prompt=True), - PromptSummaryTool(off_prompt=True), - FileManagerTool(), - ], - *args, - **kwargs, - # Add additional settings - ) - - # Override the run method to take a task and execute it using the Griptape agent - def run(self, task: str) -> str: - # Extract URL and filename from task (you can modify this parsing based on task structure) - url, filename = task.split( - "," - ) # Example of splitting task string - # Execute the Griptape agent with the task inputs - result = self.agent.run(url.strip(), filename.strip()) - # Return the final result as a string - return str(result) - - -# Example usage: -griptape_swarms_agent = GriptapeSwarmsAgent() -output = griptape_swarms_agent.run( - "https://griptape.ai, griptape.txt" -) -print(output) -``` - -### Key Components: -1. **GriptapeSwarmsAgent**: A custom class that inherits from the `SwarmsAgent` class and integrates the Griptape agent. -2. **run(task: str) -> str**: A method that takes a task string, processes it (e.g., splitting into a URL and filename), and runs the Griptape agent with the provided inputs. -3. **Griptape Tools**: The tools integrated into the Griptape agent (e.g., `WebScraperTool`, `PromptSummaryTool`, `FileManagerTool`) allow for web scraping, summarization, and file management. - -You can now easily plug this custom Griptape agent into the **Swarms Framework** and use it to run tasks! - - - - - - ## Understanding Swarms -### What is a Swarm? - A swarm refers to a group of more than two agents working collaboratively to achieve a common goal. These agents can be software entities, such as llms that interact with each other to perform complex tasks. The concept of a swarm is inspired by natural systems like ant colonies or bird flocks, where simple individual behaviors lead to complex group dynamics and problem-solving capabilities. ### How Swarm Architectures Facilitate Communication @@ -609,9 +525,6 @@ Swarm architectures are designed to establish and manage communication between a 3. **Sequential Communication**: Sequential swarms process tasks in a linear order, where each agent's output becomes the input for the next agent. This ensures that tasks with dependencies are handled in the correct sequence, maintaining the integrity of the workflow. -4. **Mesh Communication**: In mesh swarms, agents are fully connected, allowing any agent to communicate with any other agent. This setup provides high flexibility and redundancy, making it ideal for complex systems requiring dynamic interactions. - -5. **Federated Communication**: Federated swarms involve multiple independent swarms that collaborate by sharing information and results. Each swarm operates autonomously but can contribute to a larger task, enabling distributed problem-solving across different nodes. Swarm architectures leverage these communication patterns to ensure that agents work together efficiently, adapting to the specific requirements of the task at hand. By defining clear communication protocols and interaction models, swarm architectures enable the seamless orchestration of multiple agents, leading to enhanced performance and problem-solving capabilities. @@ -889,14 +802,12 @@ The `run` method returns the final output after all agents have processed the in from swarms import Agent, AgentRearrange -from swarm_models import Anthropic - # Initialize the director agent director = Agent( agent_name="Director", system_prompt="Directs the tasks for the workers", - llm=Anthropic(), + model_name="claude-2", max_loops=1, dashboard=False, streaming_on=True, @@ -912,7 +823,7 @@ director = Agent( worker1 = Agent( agent_name="Worker1", system_prompt="Generates a transcript for a youtube video on what swarms are", - llm=Anthropic(), + model_name="claude-2", max_loops=1, dashboard=False, streaming_on=True, @@ -927,7 +838,7 @@ worker1 = Agent( worker2 = Agent( agent_name="Worker2", system_prompt="Summarizes the transcript generated by Worker1", - llm=Anthropic(), + model_name="claude-2", max_loops=1, dashboard=False, streaming_on=True, @@ -1081,20 +992,12 @@ The `run` method returns the final output after all agents have processed the in ```python import os -from swarm_models import OpenAIChat from swarms import Agent, MixtureOfAgents -api_key = os.getenv("OPENAI_API_KEY") - -# Create individual agents with the OpenAIChat model -model = OpenAIChat( - openai_api_key=api_key, model_name="gpt-4", temperature=0.1 -) - # Agent 1: Financial Statement Analyzer agent1 = Agent( agent_name="FinancialStatementAnalyzer", - llm=model, + model_name="gpt-4o", system_prompt="""You are a Financial Statement Analyzer specializing in 10-K SEC reports. Your primary focus is on analyzing the financial statements, including the balance sheet, income statement, and cash flow statement. Key responsibilities: @@ -1120,7 +1023,7 @@ When analyzing, consider industry standards and compare the company's performanc # Agent 2: Risk Assessment Specialist agent2 = Agent( agent_name="RiskAssessmentSpecialist", - llm=model, + model_name="gpt-4o", system_prompt="""You are a Risk Assessment Specialist focusing on 10-K SEC reports. Your primary role is to identify, analyze, and evaluate potential risks disclosed in the report. Key responsibilities: @@ -1147,7 +1050,7 @@ Your analysis should provide a comprehensive overview of the company's risk land # Agent 3: Business Strategy Evaluator agent3 = Agent( agent_name="BusinessStrategyEvaluator", - llm=model, + model_name="gpt-4o", system_prompt="""You are a Business Strategy Evaluator specializing in analyzing 10-K SEC reports. Your focus is on assessing the company's overall strategy, market position, and future outlook. Key responsibilities: @@ -1175,7 +1078,7 @@ Your analysis should provide insights into the company's strategic direction, it # Aggregator Agent aggregator_agent = Agent( agent_name="10KReportAggregator", - llm=model, + model_name="gpt-4o", system_prompt="""You are the 10-K Report Aggregator, responsible for synthesizing and summarizing the analyses provided by the Financial Statement Analyzer, Risk Assessment Specialist, and Business Strategy Evaluator. Your goal is to create a comprehensive, coherent, and insightful summary of the 10-K SEC report. Key responsibilities: @@ -1265,9 +1168,8 @@ The `run` method returns a dictionary containing the outputs of each agent that ```python import os -from swarms import Agent +from swarms import Agent, SpreadSheetSwarm from swarm_models import OpenAIChat -from swarms.structs.spreadsheet_swarm import SpreadSheetSwarm # Define custom system prompts for each social media platform TWITTER_AGENT_SYS_PROMPT = """ @@ -1290,20 +1192,12 @@ EMAIL_AGENT_SYS_PROMPT = """ You are an Email marketing expert specializing in real estate. Your task is to write compelling email campaigns to promote properties, focusing on personalization, subject lines, and effective call-to-action strategies to drive conversions. """ -# Example usage: -api_key = os.getenv("OPENAI_API_KEY") - -# Model -model = OpenAIChat( - openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 -) - # Initialize your agents for different social media platforms agents = [ Agent( agent_name="Twitter-RealEstate-Agent", system_prompt=TWITTER_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o", max_loops=1, dynamic_temperature_enabled=True, saved_state_path="twitter_realestate_agent.json", @@ -1313,7 +1207,7 @@ agents = [ Agent( agent_name="Instagram-RealEstate-Agent", system_prompt=INSTAGRAM_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o", max_loops=1, dynamic_temperature_enabled=True, saved_state_path="instagram_realestate_agent.json", @@ -1323,7 +1217,7 @@ agents = [ Agent( agent_name="Facebook-RealEstate-Agent", system_prompt=FACEBOOK_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o", max_loops=1, dynamic_temperature_enabled=True, saved_state_path="facebook_realestate_agent.json", @@ -1333,7 +1227,7 @@ agents = [ Agent( agent_name="LinkedIn-RealEstate-Agent", system_prompt=LINKEDIN_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o", max_loops=1, dynamic_temperature_enabled=True, saved_state_path="linkedin_realestate_agent.json", @@ -1343,7 +1237,7 @@ agents = [ Agent( agent_name="Email-RealEstate-Agent", system_prompt=EMAIL_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o", max_loops=1, dynamic_temperature_enabled=True, saved_state_path="email_realestate_agent.json", @@ -1452,7 +1346,7 @@ The `run` method returns the output from the most relevant agent selected based ```python -from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm +from swarms import TreeAgent, Tree, ForestSwarm # Create agents with varying system prompts and dynamically generated distances/keywords agents_tree1 = [ diff --git a/api/agent_api.py b/api/agent_api.py new file mode 100644 index 00000000..d1968d9d --- /dev/null +++ b/api/agent_api.py @@ -0,0 +1,629 @@ +import os +from fastapi import ( + FastAPI, + HTTPException, + status, + Query, + BackgroundTasks, +) +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from typing import Optional, Dict, Any, List +from loguru import logger +import uvicorn +from datetime import datetime, timedelta +from uuid import UUID, uuid4 +from enum import Enum +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +import traceback + +from swarms import Agent +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Configure Loguru +logger.add( + "logs/api_{time}.log", + rotation="500 MB", + retention="10 days", + level="INFO", + format="{time} {level} {message}", + backtrace=True, + diagnose=True, +) + + +class AgentStatus(str, Enum): + """Enum for agent status.""" + + IDLE = "idle" + PROCESSING = "processing" + ERROR = "error" + MAINTENANCE = "maintenance" + + +class AgentConfig(BaseModel): + """Configuration model for creating a new agent.""" + + agent_name: str = Field(..., description="Name of the agent") + model_name: str = Field( + ..., + description="Name of the llm you want to use provided by litellm", + ) + description: str = Field( + default="", description="Description of the agent's purpose" + ) + system_prompt: str = Field( + ..., description="System prompt for the agent" + ) + model_name: str = Field( + default="gpt-4", description="Model name to use" + ) + temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for the model", + ) + max_loops: int = Field( + default=1, ge=1, description="Maximum number of loops" + ) + autosave: bool = Field( + default=True, description="Enable autosave" + ) + dashboard: bool = Field( + default=False, description="Enable dashboard" + ) + verbose: bool = Field( + default=True, description="Enable verbose output" + ) + dynamic_temperature_enabled: bool = Field( + default=True, description="Enable dynamic temperature" + ) + user_name: str = Field( + default="default_user", description="Username for the agent" + ) + retry_attempts: int = Field( + default=1, ge=1, description="Number of retry attempts" + ) + context_length: int = Field( + default=200000, ge=1000, description="Context length" + ) + output_type: str = Field( + default="string", description="Output type (string or json)" + ) + streaming_on: bool = Field( + default=False, description="Enable streaming" + ) + tags: List[str] = Field( + default_factory=list, + description="Tags for categorizing the agent", + ) + + +class AgentUpdate(BaseModel): + """Model for updating agent configuration.""" + + description: Optional[str] = None + system_prompt: Optional[str] = None + temperature: Optional[float] = None + max_loops: Optional[int] = None + tags: Optional[List[str]] = None + status: Optional[AgentStatus] = None + + +class AgentSummary(BaseModel): + """Summary model for agent listing.""" + + agent_id: UUID + agent_name: str + description: str + created_at: datetime + last_used: datetime + total_completions: int + tags: List[str] + status: AgentStatus + + +class AgentMetrics(BaseModel): + """Model for agent performance metrics.""" + + total_completions: int + average_response_time: float + error_rate: float + last_24h_completions: int + total_tokens_used: int + uptime_percentage: float + success_rate: float + peak_tokens_per_minute: int + + +class CompletionRequest(BaseModel): + """Model for completion requests.""" + + prompt: str = Field(..., description="The prompt to process") + agent_id: UUID = Field(..., description="ID of the agent to use") + max_tokens: Optional[int] = Field( + None, description="Maximum tokens to generate" + ) + temperature_override: Optional[float] = None + stream: bool = Field( + default=False, description="Enable streaming response" + ) + + +class CompletionResponse(BaseModel): + """Model for completion responses.""" + + agent_id: UUID + response: str + metadata: Dict[str, Any] + timestamp: datetime + processing_time: float + token_usage: Dict[str, int] + + +class AgentStore: + """Enhanced store for managing agents.""" + + def __init__(self): + self.agents: Dict[UUID, Agent] = {} + self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} + self.executor = ThreadPoolExecutor(max_workers=4) + self._ensure_directories() + + def _ensure_directories(self): + """Ensure required directories exist.""" + Path("logs").mkdir(exist_ok=True) + Path("states").mkdir(exist_ok=True) + + async def create_agent(self, config: AgentConfig) -> UUID: + """Create a new agent with the given configuration.""" + try: + + agent = Agent( + agent_name=config.agent_name, + system_prompt=config.system_prompt, + model_name=config.model_name, + max_loops=config.max_loops, + autosave=config.autosave, + dashboard=config.dashboard, + verbose=config.verbose, + dynamic_temperature_enabled=config.dynamic_temperature_enabled, + saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", + user_name=config.user_name, + retry_attempts=config.retry_attempts, + context_length=config.context_length, + return_step_meta=True, + output_type="str", + streaming_on=config.streaming_on, + ) + + agent_id = uuid4() + self.agents[agent_id] = agent + self.agent_metadata[agent_id] = { + "description": config.description, + "created_at": datetime.utcnow(), + "last_used": datetime.utcnow(), + "total_completions": 0, + "tags": config.tags, + "total_tokens": 0, + "error_count": 0, + "response_times": [], + "status": AgentStatus.IDLE, + "start_time": datetime.utcnow(), + "downtime": timedelta(), + "successful_completions": 0, + } + + logger.info(f"Created agent with ID: {agent_id}") + return agent_id + + except Exception as e: + logger.error(f"Error creating agent: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create agent: {str(e)}", + ) + + async def get_agent(self, agent_id: UUID) -> Agent: + """Retrieve an agent by ID.""" + agent = self.agents.get(agent_id) + if not agent: + logger.error(f"Agent not found: {agent_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + return agent + + async def update_agent( + self, agent_id: UUID, update: AgentUpdate + ) -> None: + """Update agent configuration.""" + agent = await self.get_agent(agent_id) + metadata = self.agent_metadata[agent_id] + + if update.system_prompt: + agent.system_prompt = update.system_prompt + if update.temperature is not None: + agent.llm.temperature = update.temperature + if update.max_loops is not None: + agent.max_loops = update.max_loops + if update.tags is not None: + metadata["tags"] = update.tags + if update.description is not None: + metadata["description"] = update.description + if update.status is not None: + metadata["status"] = update.status + if update.status == AgentStatus.MAINTENANCE: + metadata["downtime"] += ( + datetime.utcnow() - metadata["last_used"] + ) + + logger.info(f"Updated agent {agent_id}") + + async def list_agents( + self, + tags: Optional[List[str]] = None, + status: Optional[AgentStatus] = None, + ) -> List[AgentSummary]: + """List all agents, optionally filtered by tags and status.""" + summaries = [] + for agent_id, agent in self.agents.items(): + metadata = self.agent_metadata[agent_id] + + # Apply filters + if tags and not any( + tag in metadata["tags"] for tag in tags + ): + continue + if status and metadata["status"] != status: + continue + + summaries.append( + AgentSummary( + agent_id=agent_id, + agent_name=agent.agent_name, + description=metadata["description"], + created_at=metadata["created_at"], + last_used=metadata["last_used"], + total_completions=metadata["total_completions"], + tags=metadata["tags"], + status=metadata["status"], + ) + ) + return summaries + + async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics: + """Get performance metrics for an agent.""" + metadata = self.agent_metadata[agent_id] + response_times = metadata["response_times"] + + # Calculate metrics + total_time = datetime.utcnow() - metadata["start_time"] + uptime = total_time - metadata["downtime"] + uptime_percentage = ( + uptime.total_seconds() / total_time.total_seconds() + ) * 100 + + success_rate = ( + metadata["successful_completions"] + / metadata["total_completions"] + * 100 + if metadata["total_completions"] > 0 + else 0 + ) + + return AgentMetrics( + total_completions=metadata["total_completions"], + average_response_time=( + sum(response_times) / len(response_times) + if response_times + else 0 + ), + error_rate=( + metadata["error_count"] + / metadata["total_completions"] + if metadata["total_completions"] > 0 + else 0 + ), + last_24h_completions=sum( + 1 + for t in response_times + if (datetime.utcnow() - t).days < 1 + ), + total_tokens_used=metadata["total_tokens"], + uptime_percentage=uptime_percentage, + success_rate=success_rate, + peak_tokens_per_minute=max( + metadata.get("tokens_per_minute", [0]) + ), + ) + + async def clone_agent( + self, agent_id: UUID, new_name: str + ) -> UUID: + """Clone an existing agent with a new name.""" + original_agent = await self.get_agent(agent_id) + original_metadata = self.agent_metadata[agent_id] + + config = AgentConfig( + agent_name=new_name, + description=f"Clone of {original_agent.agent_name}", + system_prompt=original_agent.system_prompt, + model_name=original_agent.llm.model_name, + temperature=original_agent.llm.temperature, + max_loops=original_agent.max_loops, + tags=original_metadata["tags"], + ) + + return await self.create_agent(config) + + async def delete_agent(self, agent_id: UUID) -> None: + """Delete an agent.""" + if agent_id not in self.agents: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + + # Clean up any resources + agent = self.agents[agent_id] + if agent.autosave and os.path.exists(agent.saved_state_path): + os.remove(agent.saved_state_path) + + del self.agents[agent_id] + del self.agent_metadata[agent_id] + logger.info(f"Deleted agent {agent_id}") + + async def process_completion( + self, + agent: Agent, + prompt: str, + agent_id: UUID, + max_tokens: Optional[int] = None, + temperature_override: Optional[float] = None, + ) -> CompletionResponse: + """Process a completion request using the specified agent.""" + start_time = datetime.utcnow() + metadata = self.agent_metadata[agent_id] + + try: + # Update agent status + metadata["status"] = AgentStatus.PROCESSING + metadata["last_used"] = start_time + + # Apply temporary overrides if specified + original_temp = agent.llm.temperature + if temperature_override is not None: + agent.llm.temperature = temperature_override + + # Process the completion + response = agent.run(prompt) + + # Reset overrides + if temperature_override is not None: + agent.llm.temperature = original_temp + + # Update metrics + processing_time = ( + datetime.utcnow() - start_time + ).total_seconds() + metadata["response_times"].append(processing_time) + metadata["total_completions"] += 1 + metadata["successful_completions"] += 1 + + # Estimate token usage (this is a rough estimate) + prompt_tokens = len(prompt.split()) * 1.3 + completion_tokens = len(response.split()) * 1.3 + total_tokens = int(prompt_tokens + completion_tokens) + metadata["total_tokens"] += total_tokens + + # Update tokens per minute tracking + current_minute = datetime.utcnow().replace( + second=0, microsecond=0 + ) + if "tokens_per_minute" not in metadata: + metadata["tokens_per_minute"] = {} + metadata["tokens_per_minute"][current_minute] = ( + metadata["tokens_per_minute"].get(current_minute, 0) + + total_tokens + ) + + return CompletionResponse( + agent_id=agent_id, + response=response, + metadata={ + "agent_name": agent.agent_name, + "model_name": agent.llm.model_name, + "temperature": agent.llm.temperature, + }, + timestamp=datetime.utcnow(), + processing_time=processing_time, + token_usage={ + "prompt_tokens": int(prompt_tokens), + "completion_tokens": int(completion_tokens), + "total_tokens": total_tokens, + }, + ) + + except Exception as e: + metadata["error_count"] += 1 + metadata["status"] = AgentStatus.ERROR + logger.error( + f"Error in completion processing: {str(e)}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + finally: + metadata["status"] = AgentStatus.IDLE + + +class SwarmsAPI: + """Enhanced API class for Swarms agent integration.""" + + def __init__(self): + self.app = FastAPI( + title="Swarms Agent API", + description="Production-grade API for Swarms agent interaction", + version="1.0.0", + docs_url="/v1/docs", + redoc_url="/v1/redoc", + ) + self.store = AgentStore() + # Configure CORS + self.app.add_middleware( + CORSMiddleware, + allow_origins=[ + "*" + ], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + self._setup_routes() + + def _setup_routes(self): + """Set up API routes.""" + + @self.app.post("/v1/agent", response_model=Dict[str, UUID]) + async def create_agent(config: AgentConfig): + """Create a new agent with the specified configuration.""" + agent_id = await self.store.create_agent(config) + return {"agent_id": agent_id} + + @self.app.get("/v1/agents", response_model=List[AgentSummary]) + async def list_agents( + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents, optionally filtered by tags and status.""" + return await self.store.list_agents(tags, status) + + @self.app.patch( + "/v1/agent/{agent_id}", response_model=Dict[str, str] + ) + async def update_agent(agent_id: UUID, update: AgentUpdate): + """Update an existing agent's configuration.""" + await self.store.update_agent(agent_id, update) + return {"status": "updated"} + + @self.app.get( + "/v1/agent/{agent_id}/metrics", + response_model=AgentMetrics, + ) + async def get_agent_metrics(agent_id: UUID): + """Get performance metrics for a specific agent.""" + return await self.store.get_agent_metrics(agent_id) + + @self.app.post( + "/v1/agent/{agent_id}/clone", + response_model=Dict[str, UUID], + ) + async def clone_agent(agent_id: UUID, new_name: str): + """Clone an existing agent with a new name.""" + new_id = await self.store.clone_agent(agent_id, new_name) + return {"agent_id": new_id} + + @self.app.delete("/v1/agent/{agent_id}") + async def delete_agent(agent_id: UUID): + """Delete an agent.""" + await self.store.delete_agent(agent_id) + return {"status": "deleted"} + + @self.app.post( + "/v1/agent/completions", response_model=CompletionResponse + ) + async def create_completion( + request: CompletionRequest, + background_tasks: BackgroundTasks, + ): + """Process a completion request with the specified agent.""" + try: + agent = await self.store.get_agent(request.agent_id) + + # Process completion + response = await self.store.process_completion( + agent, + request.prompt, + request.agent_id, + request.max_tokens, + request.temperature_override, + ) + + # Schedule background cleanup + background_tasks.add_task( + self._cleanup_old_metrics, request.agent_id + ) + + return response + + except Exception as e: + logger.error(f"Error processing completion: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + + @self.app.get("/v1/agent/{agent_id}/status") + async def get_agent_status(agent_id: UUID): + """Get the current status of an agent.""" + metadata = self.store.agent_metadata.get(agent_id) + if not metadata: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + return { + "agent_id": agent_id, + "status": metadata["status"], + "last_used": metadata["last_used"], + "total_completions": metadata["total_completions"], + "error_count": metadata["error_count"], + } + + async def _cleanup_old_metrics(self, agent_id: UUID): + """Clean up old metrics data to prevent memory bloat.""" + metadata = self.store.agent_metadata.get(agent_id) + if metadata: + # Keep only last 24 hours of response times + cutoff = datetime.utcnow() - timedelta(days=1) + metadata["response_times"] = [ + t + for t in metadata["response_times"] + if isinstance(t, (int, float)) + and t > cutoff.timestamp() + ] + + # Clean up old tokens per minute data + if "tokens_per_minute" in metadata: + metadata["tokens_per_minute"] = { + k: v + for k, v in metadata["tokens_per_minute"].items() + if k > cutoff + } + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + api = SwarmsAPI() + return api.app + + +if __name__ == "__main__": + # Configure uvicorn logging + logger.info("API Starting") + uvicorn.run( + "main:create_app", + host="0.0.0.0", + port=8000, + reload=True, + workers=4, + ) diff --git a/api/agent_api_test.py b/api/agent_api_test.py new file mode 100644 index 00000000..066efc4f --- /dev/null +++ b/api/agent_api_test.py @@ -0,0 +1,107 @@ +import requests +from loguru import logger +import time + +# Configure loguru +logger.add( + "api_tests_{time}.log", + rotation="100 MB", + level="DEBUG", + format="{time} {level} {message}", +) + +BASE_URL = "http://localhost:8000/v1" + + +def test_create_agent(): + """Test creating a new agent.""" + logger.info("Testing agent creation") + + payload = { + "agent_name": "Test Agent", + "system_prompt": "You are a helpful assistant", + "model_name": "gpt-4", + "description": "Test agent", + "tags": ["test"], + } + + response = requests.post(f"{BASE_URL}/agent", json=payload) + logger.debug(f"Create response: {response.json()}") + + if response.status_code == 200: + logger.success("Successfully created agent") + return response.json()["agent_id"] + else: + logger.error(f"Failed to create agent: {response.text}") + return None + + +def test_list_agents(): + """Test listing all agents.""" + logger.info("Testing agent listing") + + response = requests.get(f"{BASE_URL}/agents") + logger.debug(f"List response: {response.json()}") + + if response.status_code == 200: + logger.success(f"Found {len(response.json())} agents") + else: + logger.error(f"Failed to list agents: {response.text}") + + +def test_completion(agent_id): + """Test running a completion.""" + logger.info("Testing completion") + + payload = { + "prompt": "What is the weather like today?", + "agent_id": agent_id, + } + + response = requests.post( + f"{BASE_URL}/agent/completions", json=payload + ) + logger.debug(f"Completion response: {response.json()}") + + if response.status_code == 200: + logger.success("Successfully got completion") + else: + logger.error(f"Failed to get completion: {response.text}") + + +def test_delete_agent(agent_id): + """Test deleting an agent.""" + logger.info("Testing agent deletion") + + response = requests.delete(f"{BASE_URL}/agent/{agent_id}") + logger.debug(f"Delete response: {response.json()}") + + if response.status_code == 200: + logger.success("Successfully deleted agent") + else: + logger.error(f"Failed to delete agent: {response.text}") + + +def run_tests(): + """Run all tests in sequence.""" + logger.info("Starting API tests") + + # Create agent and get ID + agent_id = test_create_agent() + if not agent_id: + logger.error("Cannot continue tests without agent ID") + return + + # Wait a bit for agent to be ready + time.sleep(1) + + # Run other tests + test_list_agents() + test_completion(agent_id) + test_delete_agent(agent_id) + + logger.info("Tests completed") + + +if __name__ == "__main__": + run_tests() diff --git a/byte.py b/byte.py new file mode 100644 index 00000000..d0a5a92f --- /dev/null +++ b/byte.py @@ -0,0 +1,898 @@ +from enum import Enum +from typing import Union, Optional +import io +from PIL import Image +import numpy as np +import torch +import struct + + +from enum import auto +from typing import List, Dict, Tuple +import wave +from dataclasses import dataclass +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger +from einops import rearrange +from torch import Tensor + + +@dataclass +class ModelConfig: + """Configuration for the enhanced BytePredictor model.""" + + vocab_size: int = 256 # Standard byte range + hidden_size: int = 1024 + num_layers: int = 12 + num_key_value_heads: int = 8 # For multi-query attention + num_query_heads: int = 32 # More query heads than kv heads + dropout: float = 0.1 + max_sequence_length: int = 8192 + rope_theta: float = 10000.0 + layer_norm_eps: float = 1e-5 + vocab_parallel: bool = False + qk_norm: bool = True + qk_norm_scale: float = None + attention_bias: bool = False + + +class MultiQueryAttention(nn.Module): + """Fixed Multi-Query Attention implementation.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_query_heads = config.num_query_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_query_heads + self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5) + + self.q_proj = nn.Linear( + config.hidden_size, config.num_query_heads * self.head_dim + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + ) + self.o_proj = nn.Linear( + config.num_query_heads * self.head_dim, config.hidden_size + ) + + self.qk_norm = config.qk_norm + if self.qk_norm: + self.q_norm = nn.LayerNorm(self.head_dim) + self.k_norm = nn.LayerNorm(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_length, _ = hidden_states.shape + + # Project and reshape + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to [seq_len, batch, heads, head_dim] + q = q.view( + batch_size, + seq_length, + self.num_query_heads, + self.head_dim, + ).permute(1, 0, 2, 3) + k = k.view( + batch_size, + seq_length, + self.num_key_value_heads, + self.head_dim, + ).permute(1, 0, 2, 3) + v = v.view( + batch_size, + seq_length, + self.num_key_value_heads, + self.head_dim, + ).permute(1, 0, 2, 3) + + # Apply rotary embeddings + # q, k = self.rotary(q, k, seq_length) + + # Apply QK normalization if enabled + if self.qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Handle MQA head expansion + if self.num_key_value_heads != self.num_query_heads: + k = k.repeat_interleave( + self.num_query_heads // self.num_key_value_heads, + dim=2, + ) + v = v.repeat_interleave( + self.num_query_heads // self.num_key_value_heads, + dim=2, + ) + + # Compute attention + # Reshape for matmul: [batch, heads, seq_length, head_dim] + q = q.permute(1, 2, 0, 3) + k = k.permute(1, 2, 0, 3) + v = v.permute(1, 2, 0, 3) + + attn_weights = ( + torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale + ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, dim=-1) + + output = torch.matmul(attn_weights, v) + + # Reshape back to [batch, seq_length, hidden_size] + output = ( + output.transpose(1, 2) + .contiguous() + .view(batch_size, seq_length, -1) + ) + output = self.o_proj(output) + + return output + + +class EnhancedBytePredictor(nn.Module): + """Enhanced byte prediction model with state-of-the-art features.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + # Token embeddings + self.tok_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size + ) + + # Transformer layers + self.layers = nn.ModuleList( + [ + nn.ModuleDict( + { + "attention": MultiQueryAttention(config), + "attention_norm": nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ), + "feed_forward": nn.Sequential( + nn.Linear( + config.hidden_size, + 4 * config.hidden_size, + ), + nn.GELU(), + nn.Linear( + 4 * config.hidden_size, + config.hidden_size, + ), + ), + "feed_forward_norm": nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ), + } + ) + for _ in range(config.num_layers) + ] + ) + + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.output = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module) -> None: + """Initialize weights with scaled normal distribution.""" + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the model. + + Args: + input_ids: Tensor of shape (batch_size, sequence_length) + attention_mask: Optional attention mask + + Returns: + Tensor of logits with shape (batch_size, sequence_length, vocab_size) + """ + hidden_states = self.tok_embeddings(input_ids) + + # Create causal mask if needed + if attention_mask is None: + attention_mask = torch.triu( + torch.ones( + (input_ids.size(1), input_ids.size(1)), + device=input_ids.device, + dtype=torch.bool, + ), + diagonal=1, + ) + attention_mask = attention_mask.masked_fill( + attention_mask == 1, float("-inf") + ) + + # Apply transformer layers + for layer in self.layers: + # Attention block + hidden_states = hidden_states + layer["attention"]( + layer["attention_norm"](hidden_states), attention_mask + ) + + # Feed-forward block + hidden_states = hidden_states + layer["feed_forward"]( + layer["feed_forward_norm"](hidden_states) + ) + + hidden_states = self.norm(hidden_states) + logits = self.output(hidden_states) + + return logits + + def compute_loss( + self, + input_ids: torch.Tensor, + target_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute cross entropy loss. + + Args: + input_ids: Input token ids + target_ids: Target token ids + attention_mask: Optional attention mask + + Returns: + Loss value + """ + logits = self(input_ids, attention_mask) + loss = F.cross_entropy( + rearrange(logits, "b s v -> (b s) v"), + rearrange(target_ids, "b s -> (b s)"), + ) + return loss + + @torch.no_grad() + def _generate( + self, + input_ids: torch.Tensor, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: float = 1.0, + ) -> torch.Tensor: + """ + Generate new tokens autoregressively. + + Args: + input_ids: Starting sequence + max_new_tokens: Number of tokens to generate + temperature: Sampling temperature + top_k: K for top-k sampling + top_p: P for nucleus sampling + repetition_penalty: Penalty for repeating tokens + + Returns: + Generated sequence + """ + batch_size, seq_length = input_ids.shape + generated = input_ids.clone() + + for _ in range(max_new_tokens): + if generated.size(1) >= self.config.max_sequence_length: + break + + # Forward pass + logits = self(generated)[:, -1, :] + + # Apply temperature + logits = logits / temperature + + # Apply repetition penalty + if repetition_penalty != 1.0: + for i in range(batch_size): + for token_id in set(generated[i].tolist()): + logits[i, token_id] /= repetition_penalty + + # Apply top-k sampling + if top_k is not None: + indices_to_remove = ( + logits + < torch.topk(logits, top_k)[0][..., -1, None] + ) + logits[indices_to_remove] = float("-inf") + + # Apply nucleus (top-p) sampling + if top_p is not None: + sorted_logits, sorted_indices = torch.sort( + logits, descending=True + ) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = ( + sorted_indices_to_remove[..., :-1].clone() + ) + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = torch.zeros_like( + logits, dtype=torch.bool + ) + indices_to_remove.scatter_( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = float("-inf") + + # Sample next token + probs = F.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + # Append to sequence + generated = torch.cat([generated, next_token], dim=1) + + return generated + + def generate( + self, + input_ids: torch.Tensor, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: float = 1.0, + ): + tensor_data = self._generate( + input_ids=input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + return tensor_to_data(tensor_data) + + +# import torch +# from typing import Optional + + +class DataType(Enum): + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + BINARY = "binary" + + +class ByteDetokenizer: + """Utility class for converting model output bytes back to original data formats.""" + + @staticmethod + def tensor_to_bytes(tensor: torch.Tensor) -> bytes: + """Convert model output tensor to bytes.""" + # Convert logits/probabilities to byte values + if tensor.dim() > 1: + # If we have logits, convert to byte indices + byte_indices = tensor.argmax(dim=-1) + else: + byte_indices = tensor + + # Convert to Python bytes + return bytes( + byte_indices.cpu().numpy().astype(np.uint8).tolist() + ) + + @staticmethod + def decode_text(byte_sequence: bytes) -> str: + """Convert bytes to text.""" + try: + return byte_sequence.decode("utf-8") + except UnicodeDecodeError: + # Try with error handling + return byte_sequence.decode("utf-8", errors="replace") + + @staticmethod + def decode_image( + byte_sequence: bytes, + mode: str = "RGB", + size: Optional[tuple] = None, + ) -> Image.Image: + """Convert bytes to image. + + Args: + byte_sequence: Raw image bytes + mode: Image mode (RGB, RGBA, L, etc.) + size: Optional tuple of (width, height) + """ + try: + # Try to load as-is first (for standard image formats) + img = Image.open(io.BytesIO(byte_sequence)) + if size: + img = img.resize(size) + return img + except: + # If failed, assume raw pixel data + if not size: + # Try to determine size from byte count + pixel_count = len(byte_sequence) // len(mode) + size = ( + int(np.sqrt(pixel_count)), + int(np.sqrt(pixel_count)), + ) + + # Convert raw bytes to pixel array + pixels = np.frombuffer(byte_sequence, dtype=np.uint8) + pixels = pixels.reshape((*size, len(mode))) + + return Image.fromarray(pixels, mode=mode) + + @staticmethod + def decode_audio( + byte_sequence: bytes, + sample_rate: int = 44100, + channels: int = 2, + sample_width: int = 2, + ) -> np.ndarray: + """Convert bytes to audio samples. + + Args: + byte_sequence: Raw audio bytes + sample_rate: Audio sample rate in Hz + channels: Number of audio channels + sample_width: Bytes per sample (1, 2, or 4) + """ + # Determine format string based on sample width + format_str = { + 1: "b", # signed char + 2: "h", # short + 4: "i", # int + }[sample_width] + + # Unpack bytes to samples + sample_count = len(byte_sequence) // (channels * sample_width) + samples = struct.unpack( + f"<{sample_count * channels}{format_str}", byte_sequence + ) + + # Reshape to [samples, channels] + return np.array(samples).reshape(-1, channels) + + def decode_data( + self, + model_output: Union[torch.Tensor, bytes], + data_type: DataType, + **kwargs, + ) -> Union[str, Image.Image, np.ndarray, bytes]: + """Main method to decode model output to desired format. + + Args: + model_output: Either tensor from model or raw bytes + data_type: Type of data to decode to + **kwargs: Additional parameters for specific decoders + + Returns: + Decoded data in specified format + """ + # Convert tensor to bytes if needed + if isinstance(model_output, torch.Tensor): + byte_sequence = self.tensor_to_bytes(model_output) + else: + byte_sequence = model_output + + # Decode based on type + if data_type == DataType.TEXT: + return self.decode_text(byte_sequence) + elif data_type == DataType.IMAGE: + return self.decode_image(byte_sequence, **kwargs) + elif data_type == DataType.AUDIO: + return self.decode_audio(byte_sequence, **kwargs) + elif data_type == DataType.VIDEO: + raise NotImplementedError( + "Video decoding not yet implemented" + ) + else: # BINARY + return byte_sequence + + +# Usage example + + +class Modality(Enum): + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() + BINARY = auto() + MULTIMODAL = auto() + + +@dataclass +class ModalityInfo: + """Information about detected modality.""" + + modality: Modality + confidence: float + metadata: Dict[str, any] + sub_modalities: Optional[List["ModalityInfo"]] = None + + +class ModalityDetector: + """Detects data modalities from byte sequences.""" + + # Common file signatures (magic numbers) + SIGNATURES = { + # Images + b"\xFF\xD8\xFF": "JPEG", + b"\x89PNG\r\n\x1a\n": "PNG", + b"GIF87a": "GIF", + b"GIF89a": "GIF", + b"RIFF": "WEBP", + # Audio + b"RIFF....WAVE": "WAV", + b"ID3": "MP3", + b"\xFF\xFB": "MP3", + b"OggS": "OGG", + # Video + b"\x00\x00\x00\x18ftypmp42": "MP4", + b"\x00\x00\x00\x1Cftypav01": "MP4", + b"\x1A\x45\xDF\xA3": "WEBM", + } + + def __init__(self): + self.magic = magic.Magic(mime=True) + + def _check_text_probability(self, data: bytes) -> float: + """Estimate probability that data is text.""" + # Check if data is valid UTF-8 + try: + data.decode("utf-8") + # Count printable ASCII characters + printable = sum(1 for b in data if 32 <= b <= 126) + return printable / len(data) + except UnicodeDecodeError: + return 0.0 + + def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]: + """Check if data is a valid image and extract metadata.""" + try: + with io.BytesIO(data) as bio: + img = Image.open(bio) + return True, { + "format": img.format, + "size": img.size, + "mode": img.mode, + } + except: + return False, {} + + def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]: + """Check if data is valid audio and extract metadata.""" + try: + with io.BytesIO(data) as bio: + # Try to parse as WAV + with wave.open(bio) as wav: + return True, { + "channels": wav.getnchannels(), + "sample_width": wav.getsampwidth(), + "framerate": wav.getframerate(), + "frames": wav.getnframes(), + } + except: + # Check for other audio signatures + for sig in [b"ID3", b"\xFF\xFB", b"OggS"]: + if data.startswith(sig): + return True, {"format": "compressed_audio"} + return False, {} + + def _detect_boundaries( + self, data: bytes + ) -> List[Tuple[int, int, Modality]]: + """Detect boundaries between different modalities in the data.""" + boundaries = [] + current_pos = 0 + + while current_pos < len(data): + # Look for known signatures + for sig, format_type in self.SIGNATURES.items(): + if data[current_pos:].startswith(sig): + # Found a signature, determine its length + if format_type in ["JPEG", "PNG", "GIF"]: + # Find image end + try: + with io.BytesIO( + data[current_pos:] + ) as bio: + img = Image.open(bio) + img.verify() + size = bio.tell() + boundaries.append( + ( + current_pos, + current_pos + size, + Modality.IMAGE, + ) + ) + current_pos += size + continue + except: + pass + + # Check for text sections + text_prob = self._check_text_probability( + data[current_pos : current_pos + 1024] + ) + if text_prob > 0.8: + # Look for end of text section + end_pos = current_pos + 1 + while end_pos < len(data): + if ( + self._check_text_probability( + data[end_pos : end_pos + 32] + ) + < 0.5 + ): + break + end_pos += 1 + boundaries.append( + (current_pos, end_pos, Modality.TEXT) + ) + current_pos = end_pos + continue + + current_pos += 1 + + return boundaries + + def detect_modality(self, data: bytes) -> ModalityInfo: + """Detect modality of byte sequence.""" + # First check for single modality + mime_type = self.magic.from_buffer(data) + + # Check text + text_prob = self._check_text_probability(data) + if text_prob > 0.9: + return ModalityInfo( + modality=Modality.TEXT, + confidence=text_prob, + metadata={"mime_type": mime_type}, + ) + + # Check image + is_image, image_meta = self._check_image_validity(data) + if is_image: + return ModalityInfo( + modality=Modality.IMAGE, + confidence=1.0, + metadata={**image_meta, "mime_type": mime_type}, + ) + + # Check audio + is_audio, audio_meta = self._check_audio_validity(data) + if is_audio: + return ModalityInfo( + modality=Modality.AUDIO, + confidence=1.0, + metadata={**audio_meta, "mime_type": mime_type}, + ) + + # Check for multimodal content + boundaries = self._detect_boundaries(data) + if len(boundaries) > 1: + sub_modalities = [] + for start, end, modality in boundaries: + chunk_data = data[start:end] + sub_info = self.detect_modality(chunk_data) + if sub_info.modality != Modality.BINARY: + sub_modalities.append(sub_info) + + if sub_modalities: + return ModalityInfo( + modality=Modality.MULTIMODAL, + confidence=0.8, + metadata={"mime_type": "multipart/mixed"}, + sub_modalities=sub_modalities, + ) + + # Default to binary + return ModalityInfo( + modality=Modality.BINARY, + confidence=0.5, + metadata={"mime_type": mime_type}, + ) + + def split_modalities( + self, data: bytes + ) -> List[Tuple[Modality, bytes, Dict]]: + """Split multimodal data into separate modalities.""" + boundaries = self._detect_boundaries(data) + result = [] + + for start, end, modality in boundaries: + chunk = data[start:end] + info = self.detect_modality(chunk) + result.append((modality, chunk, info.metadata)) + + return result + + +class AutoDetectBytesDecoder: + """Decoder that automatically detects and decodes different modalities.""" + + def __init__(self): + self.detector = ModalityDetector() + self.text_decoder = ByteDetokenizer() # From previous example + + def decode( + self, data: bytes + ) -> Union[str, Image.Image, np.ndarray, List[any]]: + """Automatically detect and decode byte sequence.""" + info = self.detector.detect_modality(data) + + if info.modality == Modality.MULTIMODAL: + # Handle multimodal content + parts = self.detector.split_modalities(data) + return [ + self.decode(chunk) for modality, chunk, _ in parts + ] + + if info.modality == Modality.TEXT: + return self.text_decoder.decode_text(data) + elif info.modality == Modality.IMAGE: + return self.text_decoder.decode_image(data) + elif info.modality == Modality.AUDIO: + return self.text_decoder.decode_audio(data) + else: + return data + + +# # Example usage +# def demo_auto_detection(): +# """Demonstrate auto modality detection.""" +# # Create mixed content +# text = "Hello, World!".encode('utf-8') + +# # Create a small test image +# img = Image.new('RGB', (100, 100), color='red') +# img_bytes = io.BytesIO() +# img.save(img_bytes, format='PNG') + +# # Combine into multimodal content +# mixed_content = text + img_bytes.getvalue() + +# # Initialize decoder +# decoder = AutoDetectBytesDecoder() + +# # Decode +# result = decoder.decode(mixed_content) + +# if isinstance(result, list): +# print("Detected multimodal content:") +# for i, part in enumerate(result): +# print(f"Part {i+1}: {type(part)}") + +# if __name__ == "__main__": +# demo_auto_detection() + + +def tensor_to_data(tensor: Tensor): + byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor) + + # Initialize auto-detector + decoder = AutoDetectBytesDecoder() + + # Decode with automatic detection + result = decoder.decode(byte_sequence) + + return result + + +def demo_byte_predictor(): + """Demo with smaller dimensions to test.""" + # Initialize model configuration with adjusted dimensions + config = ModelConfig( + vocab_size=256, + hidden_size=128, # Smaller for testing + num_layers=2, # Fewer layers for testing + num_key_value_heads=2, + num_query_heads=4, + dropout=0.1, + max_sequence_length=1024, + ) + + # Initialize model + model = EnhancedBytePredictor(config) + logger.info("Model initialized") + + # Move to GPU if available + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + model = model.to(device) + logger.info(f"Using device: {device}") + + # Create sample input data + batch_size = 2 + seq_length = 16 # Shorter sequence for testing + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_length), device=device + ) + logger.info(f"Created input tensor of shape: {input_ids.shape}") + + # Test forward pass + try: + logits = model(input_ids) + logger.info( + f"Forward pass successful! Output shape: {logits.shape}" + ) + + # Test loss computation + target_ids = torch.randint( + 0, + config.vocab_size, + (batch_size, seq_length), + device=device, + ) + loss = model.compute_loss(input_ids, target_ids) + logger.info( + f"Loss computation successful! Loss value: {loss.item():.4f}" + ) + + # Test generation + prompt = torch.randint( + 0, + config.vocab_size, + (1, 4), # Very short prompt for testing + device=device, + ) + generated = model.generate( + prompt, max_new_tokens=8, temperature=0.8, top_k=50 + ) + logger.info( + f"Generation successful! Generated shape: {generated.shape}" + ) + + except Exception as e: + logger.error(f"Error during execution: {str(e)}") + raise + + +if __name__ == "__main__": + # Set up logging + # logger.remove() # Remove default handler + # logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message}") + + demo_byte_predictor() diff --git a/docs/assets/css/extra.css b/docs/assets/css/extra.css index 8a8a758f..a9967e01 100644 --- a/docs/assets/css/extra.css +++ b/docs/assets/css/extra.css @@ -8,7 +8,7 @@ display: table; } -/* Dark mode */ +/* Dark mode [data-md-color-scheme="slate"] { --md-default-bg-color: black; } @@ -24,4 +24,4 @@ .md-header.md-header--shadow { color: black; -} \ No newline at end of file +} */ \ No newline at end of file diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 53b4d273..f702f1c5 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -157,6 +157,7 @@ nav: # - Build Custom Agents: "swarms/structs/diy_your_own_agent.md" - Agent Architecture: "swarms/framework/agents_explained.md" - Complete Agent API: "swarms/structs/agent.md" + - OpenAI Assistant: "swarms/agents/openai_assistant.md" - Create and Run Agents from YAML: "swarms/agents/create_agents_yaml.md" - Integrating External Agents from Griptape, Langchain, etc: "swarms/agents/external_party_agents.md" - Tools: diff --git a/docs/swarms/agents/openai_assistant.md b/docs/swarms/agents/openai_assistant.md new file mode 100644 index 00000000..d5f3b8bf --- /dev/null +++ b/docs/swarms/agents/openai_assistant.md @@ -0,0 +1,135 @@ +# OpenAI Assistant + +The OpenAI Assistant class provides a wrapper around OpenAI's Assistants API, integrating it with the swarms framework. + +## Overview + +The `OpenAIAssistant` class allows you to create and interact with OpenAI Assistants, providing a simple interface for: + +- Creating assistants with specific roles and capabilities +- Adding custom functions that the assistant can call +- Managing conversation threads +- Handling tool calls and function execution +- Getting responses from the assistant + +## Insstallation + +```bash +pip install swarms +``` +## Basic Usage + +```python + +from swarms import OpenAIAssistant + +#Create an assistant +assistant = OpenAIAssistant( + name="Math Tutor", + instructions="You are a helpful math tutor.", + model="gpt-4o", + tools=[{"type": "code_interpreter"}] +) + +#Run a Task +response = assistant.run("Solve the equation: 3x + 11 = 14") +print(response) + +# Continue the conversation in the same thread +follow_up = assistant.run("Now explain how you solved it") +print(follow_up) +``` + +## Function Calling + +The assistant supports custom function integration: + +```python + +def get_weather(location: str, unit: str = "celsius") -> str: + # Mock weather function + return f"The weather in {location} is 22 degrees {unit}" + +# Add function to assistant +assistant.add_function( + description="Get the current weather in a location", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius" + } + }, + "required": ["location"] + } +) +``` + +## API Reference + +### Constructor + +```python +OpenAIAssistant( + name: str, + instructions: Optional[str] = None, + model: str = "gpt-4o", + tools: Optional[List[Dict[str, Any]]] = None, + file_ids: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + functions: Optional[List[Dict[str, Any]]] = None, +) +``` + +### Methods + +#### run(task: str) -> str +Sends a task to the assistant and returns its response. The conversation thread is maintained between calls. + +#### add_function(func: Callable, description: str, parameters: Dict[str, Any]) -> None +Adds a callable function that the assistant can use during conversations. + +#### add_message(content: str, file_ids: Optional[List[str]] = None) -> None +Adds a message to the current conversation thread. + +## Error Handling + +The assistant implements robust error handling: +- Retries on rate limits +- Graceful handling of API errors +- Clear error messages for debugging +- Status monitoring for runs and completions + +## Best Practices + +1. Thread Management + - Use the same assistant instance for related conversations + - Create new instances for unrelated tasks + - Monitor thread status during long-running operations + +2. Function Integration + - Keep functions simple and focused + - Provide clear descriptions and parameter schemas + - Handle errors gracefully in custom functions + - Test functions independently before integration + +3. Performance + - Reuse assistant instances when possible + - Monitor and handle rate limits appropriately + - Use appropriate polling intervals for status checks + - Consider implementing timeouts for long-running operations + +## References + +- [OpenAI Assistants API Documentation](https://platform.openai.com/docs/assistants/overview) +- [OpenAI Function Calling Guide](https://platform.openai.com/docs/guides/function-calling) +- [OpenAI Rate Limits](https://platform.openai.com/docs/guides/rate-limits) + + + diff --git a/docs/swarms/install/install.md b/docs/swarms/install/install.md index f69a09bd..9d52d84e 100644 --- a/docs/swarms/install/install.md +++ b/docs/swarms/install/install.md @@ -127,7 +127,7 @@ Before you begin, ensure you have the following installed: poetry install --extras "desktop" ``` -=== "Using Docker" +=== "Using Docker COMING SOON [DOES NOT WORK YET]" Docker is an excellent option for creating isolated and reproducible environments, suitable for both development and production. diff --git a/docs/swarms/structs/async_workflow.md b/docs/swarms/structs/async_workflow.md index 4bb1471c..4f9657e7 100644 --- a/docs/swarms/structs/async_workflow.md +++ b/docs/swarms/structs/async_workflow.md @@ -203,4 +203,64 @@ await workflow.add(tasks=[task_1, task_2]) # Running the workflow results = await workflow.run() print(results) # Output: ["Task 1 Completed", "Task 2 Completed"] -``` \ No newline at end of file +``` + +# Async Workflow + +The AsyncWorkflow allows multiple agents to process tasks concurrently using Python's asyncio framework. + +## Usage Example + +```python +import asyncio +from swarms import Agent, AsyncWorkflow +from swarm_models import OpenAIChat + +# Initialize model +model = OpenAIChat( + openai_api_key="your-api-key", + model_name="gpt-4", + temperature=0.7 +) + +# Create agents +agents = [ + Agent( + agent_name=f"Analysis-Agent-{i}", + llm=model, + max_loops=1, + dashboard=False, + verbose=True, + ) + for i in range(3) +] + +# Initialize workflow +workflow = AsyncWorkflow( + name="Analysis-Workflow", + agents=agents, + max_workers=3, + verbose=True +) + +# Run workflow +async def main(): + task = "Analyze the potential impact of AI on healthcare" + results = await workflow.run(task) + for i, result in enumerate(results): + print(f"Agent {i} result: {result}") + +# Execute +asyncio.run(main()) +``` + +## Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | str | "AsyncWorkflow" | Name of the workflow | +| `agents` | List[Agent] | None | List of agents to execute tasks | +| `max_workers` | int | 5 | Maximum number of concurrent workers | +| `dashboard` | bool | False | Enable/disable dashboard | +| `autosave` | bool | False | Enable/disable autosaving results | +| `verbose` | bool | False | Enable/disable verbose logging | \ No newline at end of file diff --git a/ethereum_analysis.csv b/ethereum_analysis.csv deleted file mode 100644 index 703c4fbe..00000000 --- a/ethereum_analysis.csv +++ /dev/null @@ -1,4 +0,0 @@ -timestamp,transaction_hash,from_address,to_address,value_eth,gas_used,gas_price_gwei,block_number,analysis -2024-11-27T13:50:35,ddbb665bc75fe848e7ce3d3ce1729243e92466c38ca407deccce8bf629987652,0x267be1C1D684F78cb4F6a176C4911b741E4Ffdc0,0xa40dFEE99E1C85DC97Fdc594b16A460717838703,3200.0,21000,19.968163737,21281878,"Transaction Analysis: This transaction represents a significant transfer of value in the Ethereum network with 3200 ETH (~$6.72 million USD at the current rate) moved from one address to another. It is essential to note that this transaction did not involve smart contract interaction, suggesting it could be a straightforward transfer of funds rather than part of a more complex operation. Looking at the broader market context, large transactions like this can potentially indicate major investment activities or redistribution of assets, which can have ripple effects in the market. If this transaction is part of a larger pattern of significant transfers, it could suggest substantial liquidity moving in the Ethereum ecosystem, possibly affecting the ETH prices. From a DeFi point of view, since there's no contract interaction, it's difficult to infer any direct implications. However, given the substantial value involved, it could be a step in preparation for involvement in DeFi protocols or a move from one DeFi platform to another by a large investor. The transaction fee paid, calculated from the given Gas Used and Gas Price, appears to be within reasonable range. This suggests that the transaction was not rushed and that the sender was willing to wait for this transaction to be confirmed, which might hint towards the non-urgent nature of the transaction. As for potential risk factors or security concerns, the transaction itself appears to be standard and doesn't raise any immediate red flags. However, the parties involved should always be cautious about the address security, maintaining privacy, and avoiding social engineering attacks. For traders and investors, this transaction can be interpreted as a potential bullish sign if it signifies increased liquidity and investment in the Ethereum market, especially if it's followed by similar large transfers. However, due to the anonymous nature of the transaction, it's critical to combine this with other market indicators and not to rely solely on transaction analysis for investment decisions." -2024-11-27T13:52:23,b98bcbf6d57a158b67a126d8f023766e03fb15c3e74becc1189d4244fda61a13,0xEae7380dD4CeF6fbD1144F49E4D1e6964258A4F4,0x28C6c06298d514Db089934071355E5743bf21d60,401.99463589018103,21000,14.978063737,21281887,"Ethereum-Analysis-Agent: Transaction Analysis: This transaction marks a significant transfer of 401.99 ETH, approximately $845,000 at the current rate. The transaction did not involve any smart contract interaction, suggesting a simple fund transfer rather than a complicated operation or interaction with a DeFi protocol. From a broader market perspective, this transaction is meaningful but not as potentially impactful as larger transactions. It can nonetheless be part of a larger pattern of asset movement within the Ethereum ecosystem. If this transaction is part of larger investment activities, it could suggest an increase in demand for ETH and potentially impact its price. Without contract interaction, it's challenging to assess direct implications for DeFi protocols. However, the substantial ETH transfer could suggest a step towards participation in DeFi activities, or a movement of funds between different DeFi platforms. The transaction fee appears reasonable, given the Gas Used and Gas Price. This implies that the transaction wasn't urgent, and the sender was willing to wait for the transaction to be confirmed, indicating a non-critical movement of funds. In terms of security and risk factors, there are no immediate concerns from the transaction itself. Nevertheless, as with any crypto transaction, the parties involved should ensure secure storage of their keys, maintain privacy, and be wary of potential phishing or social engineering attacks. For traders and investors, this transaction could be seen as a bullish sign if it forms part of a trend of increased investment activities in the Ethereum market. However, it's important to remember that transaction analysis should be combined with other market indicators due to the anonymous nature of blockchain transactions." -2024-11-27T13:59:47,a985b74fd3dfee09cbe4a2e6890509e583a3f0ce13f68c98e82996e0f66428be,0xf7858Da8a6617f7C6d0fF2bcAFDb6D2eeDF64840,0xA294cCa691e4C83B1fc0c8d63D9a3eeF0A196DE1,136.0668,494665.408728,3635.46,21000,18.866443971,21281923,"1. MARKET CONTEXT The transaction of 136.07 ETH, equivalent to $494,665.41, is a significant movement in the Ethereum market. However, compared to the daily trading volume of Ethereum, which often exceeds billions of dollars, this transaction is not large enough to significantly impact the ETH price on its own. 2. BEHAVIORAL ANALYSIS The transaction does not appear to be a protocol movement as there is no contract interaction involved. It could be a whale movement, given the substantial amount of ETH transferred. However, without additional information about the wallets involved, it's difficult to definitively determine the nature of the transaction. The gas price of 18.87 Gwei is relatively standard, suggesting that the transaction was not urgent or time-sensitive. 3. RISK & IMPLICATIONS The transaction does not show signs of market manipulation or unusual activity. The absence of contract interaction suggests that this transaction does not directly involve DeFi protocols, reducing the risk of smart contract vulnerabilities or DeFi-related risks. However, the large amount of ETH transferred could potentially influence market sentiment if it is part of a larger trend of similar transactions. 4. STRATEGIC INSIGHTS Traders should note this transaction as part of the broader market activity. While a single transaction of this size is unlikely to significantly impact the market, a series of similar transactions could indicate a larger trend. If this is part of a larger movement of ETH out of exchanges, it could suggest a decrease in selling pressure, which could be bullish for ETH. Conversely, if this is part of a larger movement into exchanges, it could indicate an increase in selling pressure, which could be bearish for ETH. Traders should monitor the market for further similar transactions to gain a better understanding of the potential market trends." diff --git a/async_agents.py b/new_features_examples/async_agents.py similarity index 95% rename from async_agents.py rename to new_features_examples/async_agents.py index 0d9353db..8734cd8a 100644 --- a/async_agents.py +++ b/new_features_examples/async_agents.py @@ -7,7 +7,7 @@ from swarms import Agent from swarms.prompts.finance_agent_sys_prompt import ( FINANCIAL_AGENT_SYS_PROMPT, ) -from async_executor import HighSpeedExecutor +from new_features_examples.async_executor import HighSpeedExecutor load_dotenv() diff --git a/async_executor.py b/new_features_examples/async_executor.py similarity index 100% rename from async_executor.py rename to new_features_examples/async_executor.py diff --git a/new_features_examples/auto_agent.py b/new_features_examples/auto_agent.py new file mode 100644 index 00000000..712be089 --- /dev/null +++ b/new_features_examples/auto_agent.py @@ -0,0 +1,188 @@ +import json +import os +from contextlib import suppress +from typing import Any, Callable, Dict, Optional, Type, Union + +from dotenv import load_dotenv +from pydantic import BaseModel, Field, ValidationError, create_model +from swarm_models.openai_function_caller import OpenAIFunctionCaller + + +class DynamicParser: + @staticmethod + def extract_fields(model: Type[BaseModel]) -> Dict[str, Any]: + return { + field_name: (field.annotation, ... if field.is_required() else None) + for field_name, field in model.model_fields.items() + } + + @staticmethod + def create_partial_model(model: Type[BaseModel], data: Dict[str, Any]) -> Type[BaseModel]: + fields = { + field_name: (field.annotation, ... if field.is_required() else None) + for field_name, field in model.model_fields.items() + if field_name in data + } + return create_model(f"Partial{model.__name__}", **fields) + + @classmethod + def parse(cls, data: Union[str, Dict[str, Any]], model: Type[BaseModel]) -> Optional[BaseModel]: + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError: + return None + + # Try full model first + with suppress(ValidationError): + return model.model_validate(data) + + # Create and try partial model + partial_model = cls.create_partial_model(model, data) + with suppress(ValidationError): + return partial_model.model_validate(data) + + return None + + +load_dotenv() + +# Define the Thoughts schema +class Thoughts(BaseModel): + text: str = Field(..., description="Current thoughts or observations regarding the task.") + reasoning: str = Field(..., description="Logical reasoning behind the thought process.") + plan: str = Field(..., description="A short bulleted list that conveys the immediate and long-term plan.") + criticism: str = Field(..., description="Constructive self-criticism to improve future responses.") + speak: str = Field(..., description="A concise summary of thoughts intended for the user.") + +# Define the Command schema +class Command(BaseModel): + name: str = Field(..., description="Command name to execute from the provided list of commands.") + args: Dict[str, Any] = Field(..., description="Arguments required to execute the command.") + +# Define the AgentResponse schema +class AgentResponse(BaseModel): + thoughts: Thoughts = Field(..., description="The agent's current thoughts and reasoning.") + command: Command = Field(..., description="The command to execute along with its arguments.") + + + +# Define tool functions +def fluid_api_command(task: str): + """Execute a fluid API request.""" + # response = fluid_api_request(task) + print(response.model_dump_json(indent=4)) + return response + + +def send_tweet_command(text: str): + """Simulate sending a tweet.""" + print(f"Tweet sent: {text}") + return {"status": "success", "message": f"Tweet sent: {text}"} + + +def do_nothing_command(): + """Do nothing.""" + print("Doing nothing...") + return {"status": "success", "message": "No action taken."} + + +def task_complete_command(reason: str): + """Mark the task as complete and provide a reason.""" + print(f"Task completed: {reason}") + return {"status": "success", "message": f"Task completed: {reason}"} + + +# Dynamic command execution +def execute_command(name: str, args: Dict[str, Any]): + """Dynamically execute a command based on its name and arguments.""" + command_map: Dict[str, Callable] = { + "fluid_api": lambda **kwargs: fluid_api_command(task=kwargs.get("task")), + "send_tweet": lambda **kwargs: send_tweet_command(text=kwargs.get("text")), + "do_nothing": lambda **kwargs: do_nothing_command(), + "task_complete": lambda **kwargs: task_complete_command(reason=kwargs.get("reason")), + } + + if name not in command_map: + raise ValueError(f"Unknown command: {name}") + + # Execute the command with the provided arguments + return command_map[name](**args) + + +def parse_and_execute_command(response: Union[str, Dict[str, Any]], base_model: Type[BaseModel] = AgentResponse) -> Any: + """Enhanced command parser with flexible input handling""" + parsed = DynamicParser.parse(response, base_model) + if not parsed: + raise ValueError("Failed to parse response") + + if hasattr(parsed, 'command'): + command_name = parsed.command.name + command_args = parsed.command.args + return execute_command(command_name, command_args) + + return parsed + + +ainame = "AutoAgent" +userprovided = "assistant" + +SYSTEM_PROMPT = f""" +You are {ainame}, an advanced and autonomous {userprovided}. +Your role is to make decisions and complete tasks independently without seeking user assistance. Leverage your strengths as an LLM to solve tasks efficiently, adhering strictly to the commands and resources provided. + +### GOALS: +1. {userprovided} +2. Execute tasks with precision and efficiency. +3. Ensure outputs are actionable and aligned with the user's objectives. +4. Continuously optimize task strategies for maximum effectiveness. +5. Maintain reliability and consistency in all responses. + +### CONSTRAINTS: +1. Memory limit: ~4000 words for short-term memory. Save essential information to files immediately to avoid loss. +2. Independent decision-making: Do not rely on user assistance. +3. Exclusively use commands in double quotes (e.g., "command name"). +4. Use subprocesses for commands that may take longer than a few minutes. +5. Ensure all outputs strictly adhere to the specified JSON response format. + +### COMMANDS: +1. Fluid API: "fluid_api", args: "method": "", "url": "", "headers": "", "body": "" +18. Send Tweet: "send_tweet", args: "text": "" +19. Do Nothing: "do_nothing", args: +20. Task Complete (Shutdown): "task_complete", args: "reason": "" + +### RESOURCES: +1. Internet access for real-time information and data gathering. +2. Long-term memory management for storing critical information. +3. Access to GPT-3.5-powered Agents for delegating tasks. +4. File handling capabilities for output storage and retrieval. + +### PERFORMANCE EVALUATION: +1. Continuously analyze and reflect on actions to ensure optimal task completion. +2. Self-critique decisions and strategies constructively to identify areas for improvement. +3. Ensure every command serves a clear purpose and minimizes resource usage. +4. Complete tasks in the least number of steps, balancing speed and accuracy. + +### RESPONSE FORMAT: +Always respond in a strict JSON format as described below. Ensure your responses can be parsed with Python's `json.loads`: +""" + +# Initialize the OpenAIFunctionCaller +model = OpenAIFunctionCaller( + system_prompt=SYSTEM_PROMPT, + max_tokens=4000, + temperature=0.9, + base_model=AgentResponse, # Pass the Pydantic schema as the base model + parallel_tool_calls=False, + openai_api_key=os.getenv("OPENAI_API_KEY") +) + +# Example usage +user_input = ( + "Analyze the provided Python code for inefficiencies, generate suggestions for improvements, " + "and provide optimized code." +) + +response = model.run(user_input) +response = parse_and_execute_command(response) +print(response) diff --git a/concurrent_mix.py b/new_features_examples/concurrent_mix.py similarity index 100% rename from concurrent_mix.py rename to new_features_examples/concurrent_mix.py diff --git a/dict_to_table.py b/new_features_examples/dict_to_table.py similarity index 100% rename from dict_to_table.py rename to new_features_examples/dict_to_table.py diff --git a/ethchain_agent.py b/new_features_examples/ethchain_agent.py similarity index 100% rename from ethchain_agent.py rename to new_features_examples/ethchain_agent.py diff --git a/new_features_examples/microstructure.py b/new_features_examples/microstructure.py new file mode 100644 index 00000000..c13d2e3f --- /dev/null +++ b/new_features_examples/microstructure.py @@ -0,0 +1,1074 @@ +import os +import threading +import time +from collections import deque +from dataclasses import dataclass +from datetime import datetime +from queue import Queue +from typing import Any, Dict, List, Optional, Tuple + +import ccxt +import numpy as np +import pandas as pd +from dotenv import load_dotenv +from loguru import logger +from scipy import stats +from swarm_models import OpenAIChat + +from swarms import Agent + +logger.enable("") + + +@dataclass +class MarketSignal: + timestamp: datetime + signal_type: str + source: str + data: Dict[str, Any] + confidence: float + metadata: Dict[str, Any] + + +class MarketDataBuffer: + def __init__(self, max_size: int = 10000): + self.max_size = max_size + self.data = deque(maxlen=max_size) + self.lock = threading.Lock() + + def add(self, item: Any) -> None: + with self.lock: + self.data.append(item) + + def get_latest(self, n: int = None) -> List[Any]: + with self.lock: + if n is None: + return list(self.data) + return list(self.data)[-n:] + + +class SignalCSVWriter: + def __init__(self, output_dir: str = "market_data"): + self.output_dir = output_dir + self.ensure_output_dir() + self.files = {} + + def ensure_output_dir(self): + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + + def get_filename(self, signal_type: str, symbol: str) -> str: + date_str = datetime.now().strftime("%Y%m%d") + return ( + f"{self.output_dir}/{signal_type}_{symbol}_{date_str}.csv" + ) + + def write_order_book_signal(self, signal: MarketSignal): + symbol = signal.data["symbol"] + metrics = signal.data["metrics"] + filename = self.get_filename("order_book", symbol) + + # Create header if file doesn't exist + if not os.path.exists(filename): + header = [ + "timestamp", + "symbol", + "bid_volume", + "ask_volume", + "mid_price", + "bid_vwap", + "ask_vwap", + "spread", + "depth_imbalance", + "confidence", + ] + with open(filename, "w") as f: + f.write(",".join(header) + "\n") + + # Write data + data = [ + str(signal.timestamp), + symbol, + str(metrics["bid_volume"]), + str(metrics["ask_volume"]), + str(metrics["mid_price"]), + str(metrics["bid_vwap"]), + str(metrics["ask_vwap"]), + str(metrics["spread"]), + str(metrics["depth_imbalance"]), + str(signal.confidence), + ] + + with open(filename, "a") as f: + f.write(",".join(data) + "\n") + + def write_tick_signal(self, signal: MarketSignal): + symbol = signal.data["symbol"] + metrics = signal.data["metrics"] + filename = self.get_filename("tick_data", symbol) + + if not os.path.exists(filename): + header = [ + "timestamp", + "symbol", + "vwap", + "price_momentum", + "volume_mean", + "trade_intensity", + "kyle_lambda", + "roll_spread", + "confidence", + ] + with open(filename, "w") as f: + f.write(",".join(header) + "\n") + + data = [ + str(signal.timestamp), + symbol, + str(metrics["vwap"]), + str(metrics["price_momentum"]), + str(metrics["volume_mean"]), + str(metrics["trade_intensity"]), + str(metrics["kyle_lambda"]), + str(metrics["roll_spread"]), + str(signal.confidence), + ] + + with open(filename, "a") as f: + f.write(",".join(data) + "\n") + + def write_arbitrage_signal(self, signal: MarketSignal): + if ( + "best_opportunity" not in signal.data + or not signal.data["best_opportunity"] + ): + return + + symbol = signal.data["symbol"] + opp = signal.data["best_opportunity"] + filename = self.get_filename("arbitrage", symbol) + + if not os.path.exists(filename): + header = [ + "timestamp", + "symbol", + "buy_venue", + "sell_venue", + "spread", + "return", + "buy_price", + "sell_price", + "confidence", + ] + with open(filename, "w") as f: + f.write(",".join(header) + "\n") + + data = [ + str(signal.timestamp), + symbol, + opp["buy_venue"], + opp["sell_venue"], + str(opp["spread"]), + str(opp["return"]), + str(opp["buy_price"]), + str(opp["sell_price"]), + str(signal.confidence), + ] + + with open(filename, "a") as f: + f.write(",".join(data) + "\n") + + +class ExchangeManager: + def __init__(self): + self.available_exchanges = { + "kraken": ccxt.kraken, + "coinbase": ccxt.coinbase, + "kucoin": ccxt.kucoin, + "bitfinex": ccxt.bitfinex, + "gemini": ccxt.gemini, + } + self.active_exchanges = {} + self.test_exchanges() + + def test_exchanges(self): + """Test each exchange and keep only the accessible ones""" + for name, exchange_class in self.available_exchanges.items(): + try: + exchange = exchange_class() + exchange.load_markets() + self.active_exchanges[name] = exchange + logger.info(f"Successfully connected to {name}") + except Exception as e: + logger.warning(f"Could not connect to {name}: {e}") + + def get_primary_exchange(self) -> Optional[ccxt.Exchange]: + """Get the first available exchange""" + if not self.active_exchanges: + raise RuntimeError("No exchanges available") + return next(iter(self.active_exchanges.values())) + + def get_all_active_exchanges(self) -> Dict[str, ccxt.Exchange]: + """Get all active exchanges""" + return self.active_exchanges + + +class BaseMarketAgent(Agent): + def __init__( + self, + agent_name: str, + system_prompt: str, + api_key: str, + model_name: str = "gpt-4-0125-preview", + temperature: float = 0.1, + ): + model = OpenAIChat( + openai_api_key=api_key, + model_name=model_name, + temperature=temperature, + ) + super().__init__( + agent_name=agent_name, + system_prompt=system_prompt, + llm=model, + max_loops=1, + autosave=True, + dashboard=False, + verbose=True, + dynamic_temperature_enabled=True, + context_length=200000, + streaming_on=True, + output_type="str", + ) + self.signal_queue = Queue() + self.is_running = False + self.last_update = datetime.now() + self.update_interval = 1.0 # seconds + + def rate_limit_check(self) -> bool: + current_time = datetime.now() + if ( + current_time - self.last_update + ).total_seconds() < self.update_interval: + return False + self.last_update = current_time + return True + + +class OrderBookAgent(BaseMarketAgent): + def __init__(self, api_key: str): + system_prompt = """ + You are an Order Book Analysis Agent specialized in detecting institutional flows. + Monitor order book depth and changes to identify potential large trades and institutional activity. + Analyze patterns in order placement and cancellation rates. + """ + super().__init__("OrderBookAgent", system_prompt, api_key) + exchange_manager = ExchangeManager() + self.exchange = exchange_manager.get_primary_exchange() + self.order_book_buffer = MarketDataBuffer(max_size=100) + self.vwap_window = 20 + + def calculate_order_book_metrics( + self, order_book: Dict + ) -> Dict[str, float]: + bids = np.array(order_book["bids"]) + asks = np.array(order_book["asks"]) + + # Calculate key metrics + bid_volume = np.sum(bids[:, 1]) + ask_volume = np.sum(asks[:, 1]) + mid_price = (bids[0][0] + asks[0][0]) / 2 + + # Calculate VWAP + bid_vwap = ( + np.sum( + bids[: self.vwap_window, 0] + * bids[: self.vwap_window, 1] + ) + / bid_volume + if bid_volume > 0 + else 0 + ) + ask_vwap = ( + np.sum( + asks[: self.vwap_window, 0] + * asks[: self.vwap_window, 1] + ) + / ask_volume + if ask_volume > 0 + else 0 + ) + + # Calculate order book slope + bid_slope = np.polyfit( + range(len(bids[:10])), bids[:10, 0], 1 + )[0] + ask_slope = np.polyfit( + range(len(asks[:10])), asks[:10, 0], 1 + )[0] + + return { + "bid_volume": bid_volume, + "ask_volume": ask_volume, + "mid_price": mid_price, + "bid_vwap": bid_vwap, + "ask_vwap": ask_vwap, + "bid_slope": bid_slope, + "ask_slope": ask_slope, + "spread": asks[0][0] - bids[0][0], + "depth_imbalance": (bid_volume - ask_volume) + / (bid_volume + ask_volume), + } + + def detect_large_orders( + self, metrics: Dict[str, float], threshold: float = 2.0 + ) -> bool: + historical_books = self.order_book_buffer.get_latest(20) + if not historical_books: + return False + + # Calculate historical volume statistics + hist_volumes = [ + book["bid_volume"] + book["ask_volume"] + for book in historical_books + ] + volume_mean = np.mean(hist_volumes) + volume_std = np.std(hist_volumes) + + current_volume = metrics["bid_volume"] + metrics["ask_volume"] + z_score = (current_volume - volume_mean) / ( + volume_std if volume_std > 0 else 1 + ) + + return abs(z_score) > threshold + + def analyze_order_book(self, symbol: str) -> MarketSignal: + if not self.rate_limit_check(): + return None + + try: + order_book = self.exchange.fetch_order_book( + symbol, limit=100 + ) + metrics = self.calculate_order_book_metrics(order_book) + self.order_book_buffer.add(metrics) + + # Format data for LLM analysis + analysis_prompt = f""" + Analyze this order book for {symbol}: + Bid Volume: {metrics['bid_volume']} + Ask Volume: {metrics['ask_volume']} + Mid Price: {metrics['mid_price']} + Spread: {metrics['spread']} + Depth Imbalance: {metrics['depth_imbalance']} + + What patterns do you see? Is there evidence of institutional activity? + Are there any significant imbalances that could lead to price movement? + """ + + # Get LLM analysis + llm_analysis = self.run(analysis_prompt) + + # Original signal creation with added LLM analysis + return MarketSignal( + timestamp=datetime.now(), + signal_type="order_book_analysis", + source="OrderBookAgent", + data={ + "metrics": metrics, + "large_order_detected": self.detect_large_orders( + metrics + ), + "symbol": symbol, + "llm_analysis": llm_analysis, # Add LLM insights + }, + confidence=min( + abs(metrics["depth_imbalance"]) * 0.7 + + ( + 1.0 + if self.detect_large_orders(metrics) + else 0.0 + ) + * 0.3, + 1.0, + ), + metadata={ + "update_latency": ( + datetime.now() - self.last_update + ).total_seconds(), + "buffer_size": len( + self.order_book_buffer.get_latest() + ), + }, + ) + except Exception as e: + logger.error(f"Error in order book analysis: {str(e)}") + return None + + +class TickDataAgent(BaseMarketAgent): + def __init__(self, api_key: str): + system_prompt = """ + You are a Tick Data Analysis Agent specialized in analyzing high-frequency price movements. + Monitor tick-by-tick data for patterns indicating short-term price direction. + Analyze trade size distribution and execution speed. + """ + super().__init__("TickDataAgent", system_prompt, api_key) + self.tick_buffer = MarketDataBuffer(max_size=5000) + exchange_manager = ExchangeManager() + self.exchange = exchange_manager.get_primary_exchange() + + def calculate_tick_metrics( + self, ticks: List[Dict] + ) -> Dict[str, float]: + df = pd.DataFrame(ticks) + df["price"] = pd.to_numeric(df["price"]) + df["volume"] = pd.to_numeric(df["amount"]) + + # Calculate key metrics + metrics = {} + + # Volume-weighted average price (VWAP) + metrics["vwap"] = (df["price"] * df["volume"]).sum() / df[ + "volume" + ].sum() + + # Price momentum + metrics["price_momentum"] = df["price"].diff().mean() + + # Volume profile + metrics["volume_mean"] = df["volume"].mean() + metrics["volume_std"] = df["volume"].std() + + # Trade intensity + time_diff = ( + df["timestamp"].max() - df["timestamp"].min() + ) / 1000 # Convert to seconds + metrics["trade_intensity"] = ( + len(df) / time_diff if time_diff > 0 else 0 + ) + + # Microstructure indicators + metrics["kyle_lambda"] = self.calculate_kyle_lambda(df) + metrics["roll_spread"] = self.calculate_roll_spread(df) + + return metrics + + def calculate_kyle_lambda(self, df: pd.DataFrame) -> float: + """Calculate Kyle's Lambda (price impact coefficient)""" + try: + price_changes = df["price"].diff().dropna() + volume_changes = df["volume"].diff().dropna() + + if len(price_changes) > 1 and len(volume_changes) > 1: + slope, _, _, _, _ = stats.linregress( + volume_changes, price_changes + ) + return abs(slope) + except Exception as e: + logger.warning(f"Error calculating Kyle's Lambda: {e}") + return 0.0 + + def calculate_roll_spread(self, df: pd.DataFrame) -> float: + """Calculate Roll's implied spread""" + try: + price_changes = df["price"].diff().dropna() + if len(price_changes) > 1: + autocov = np.cov( + price_changes[:-1], price_changes[1:] + )[0][1] + return 2 * np.sqrt(-autocov) if autocov < 0 else 0.0 + except Exception as e: + logger.warning(f"Error calculating Roll spread: {e}") + return 0.0 + + def calculate_tick_metrics( + self, ticks: List[Dict] + ) -> Dict[str, float]: + try: + # Debug the incoming data structure + logger.info( + f"Raw tick data structure: {ticks[0] if ticks else 'No ticks'}" + ) + + # Convert trades to proper format + formatted_trades = [] + for trade in ticks: + formatted_trade = { + "price": float( + trade.get("price", trade.get("last", 0)) + ), # Handle different exchange formats + "amount": float( + trade.get( + "amount", + trade.get( + "size", trade.get("quantity", 0) + ), + ) + ), + "timestamp": trade.get( + "timestamp", int(time.time() * 1000) + ), + } + formatted_trades.append(formatted_trade) + + df = pd.DataFrame(formatted_trades) + + if df.empty: + logger.warning("No valid trades to analyze") + return { + "vwap": 0.0, + "price_momentum": 0.0, + "volume_mean": 0.0, + "volume_std": 0.0, + "trade_intensity": 0.0, + "kyle_lambda": 0.0, + "roll_spread": 0.0, + } + + # Calculate metrics with the properly formatted data + metrics = {} + metrics["vwap"] = ( + (df["price"] * df["amount"]).sum() + / df["amount"].sum() + if not df.empty + else 0 + ) + metrics["price_momentum"] = ( + df["price"].diff().mean() if len(df) > 1 else 0 + ) + metrics["volume_mean"] = df["amount"].mean() + metrics["volume_std"] = df["amount"].std() + + time_diff = ( + (df["timestamp"].max() - df["timestamp"].min()) / 1000 + if len(df) > 1 + else 1 + ) + metrics["trade_intensity"] = ( + len(df) / time_diff if time_diff > 0 else 0 + ) + + metrics["kyle_lambda"] = self.calculate_kyle_lambda(df) + metrics["roll_spread"] = self.calculate_roll_spread(df) + + logger.info(f"Calculated metrics: {metrics}") + return metrics + + except Exception as e: + logger.error( + f"Error in calculate_tick_metrics: {str(e)}", + exc_info=True, + ) + # Return default metrics on error + return { + "vwap": 0.0, + "price_momentum": 0.0, + "volume_mean": 0.0, + "volume_std": 0.0, + "trade_intensity": 0.0, + "kyle_lambda": 0.0, + "roll_spread": 0.0, + } + + def analyze_ticks(self, symbol: str) -> MarketSignal: + if not self.rate_limit_check(): + return None + + try: + # Fetch recent trades + trades = self.exchange.fetch_trades(symbol, limit=100) + + # Debug the raw trades data + logger.info(f"Fetched {len(trades)} trades for {symbol}") + if trades: + logger.info(f"Sample trade: {trades[0]}") + + self.tick_buffer.add(trades) + recent_ticks = self.tick_buffer.get_latest(1000) + metrics = self.calculate_tick_metrics(recent_ticks) + + # Only proceed with LLM analysis if we have valid metrics + if metrics["vwap"] > 0: + analysis_prompt = f""" + Analyze these trading patterns for {symbol}: + VWAP: {metrics['vwap']:.2f} + Price Momentum: {metrics['price_momentum']:.2f} + Trade Intensity: {metrics['trade_intensity']:.2f} + Kyle's Lambda: {metrics['kyle_lambda']:.2f} + + What does this tell us about: + 1. Current market sentiment + 2. Potential price direction + 3. Trading activity patterns + """ + llm_analysis = self.run(analysis_prompt) + else: + llm_analysis = "Insufficient data for analysis" + + return MarketSignal( + timestamp=datetime.now(), + signal_type="tick_analysis", + source="TickDataAgent", + data={ + "metrics": metrics, + "symbol": symbol, + "prediction": np.sign(metrics["price_momentum"]), + "llm_analysis": llm_analysis, + }, + confidence=min(metrics["trade_intensity"] / 100, 1.0) + * 0.4 + + min(metrics["kyle_lambda"], 1.0) * 0.6, + metadata={ + "update_latency": ( + datetime.now() - self.last_update + ).total_seconds(), + "buffer_size": len(self.tick_buffer.get_latest()), + }, + ) + + except Exception as e: + logger.error( + f"Error in tick analysis: {str(e)}", exc_info=True + ) + return None + + +class LatencyArbitrageAgent(BaseMarketAgent): + def __init__(self, api_key: str): + system_prompt = """ + You are a Latency Arbitrage Agent specialized in detecting price discrepancies across venues. + Monitor multiple exchanges for price differences exceeding transaction costs. + Calculate optimal trade sizes and routes. + """ + super().__init__( + "LatencyArbitrageAgent", system_prompt, api_key + ) + exchange_manager = ExchangeManager() + self.exchanges = exchange_manager.get_all_active_exchanges() + self.fee_structure = { + "kraken": 0.0026, # 0.26% taker fee + "coinbase": 0.006, # 0.6% taker fee + "kucoin": 0.001, # 0.1% taker fee + "bitfinex": 0.002, # 0.2% taker fee + "gemini": 0.003, # 0.3% taker fee + } + self.price_buffer = { + ex: MarketDataBuffer(max_size=100) + for ex in self.exchanges + } + + def calculate_effective_prices( + self, ticker: Dict, venue: str + ) -> Tuple[float, float]: + """Calculate effective prices including fees""" + fee = self.fee_structure[venue] + return ( + ticker["bid"] * (1 - fee), # Effective sell price + ticker["ask"] * (1 + fee), # Effective buy price + ) + + def calculate_arbitrage_metrics( + self, prices: Dict[str, Dict] + ) -> Dict: + opportunities = [] + + for venue1 in prices: + for venue2 in prices: + if venue1 != venue2: + sell_price, _ = self.calculate_effective_prices( + prices[venue1], venue1 + ) + _, buy_price = self.calculate_effective_prices( + prices[venue2], venue2 + ) + + spread = sell_price - buy_price + if spread > 0: + opportunities.append( + { + "sell_venue": venue1, + "buy_venue": venue2, + "spread": spread, + "return": spread / buy_price, + "buy_price": buy_price, + "sell_price": sell_price, + } + ) + + return { + "opportunities": opportunities, + "best_opportunity": ( + max(opportunities, key=lambda x: x["return"]) + if opportunities + else None + ), + } + + def find_arbitrage(self, symbol: str) -> MarketSignal: + """ + Find arbitrage opportunities across exchanges with LLM analysis + """ + if not self.rate_limit_check(): + return None + + try: + prices = {} + timestamps = {} + + for name, exchange in self.exchanges.items(): + try: + ticker = exchange.fetch_ticker(symbol) + prices[name] = { + "bid": ticker["bid"], + "ask": ticker["ask"], + } + timestamps[name] = ticker["timestamp"] + self.price_buffer[name].add(prices[name]) + except Exception as e: + logger.warning( + f"Error fetching {name} price: {e}" + ) + + if len(prices) < 2: + return None + + metrics = self.calculate_arbitrage_metrics(prices) + + if not metrics["best_opportunity"]: + return None + + # Calculate confidence based on spread and timing + opp = metrics["best_opportunity"] + timing_factor = 1.0 - min( + abs( + timestamps[opp["sell_venue"]] + - timestamps[opp["buy_venue"]] + ) + / 1000, + 1.0, + ) + spread_factor = min( + opp["return"] * 5, 1.0 + ) # Scale return to confidence + + confidence = timing_factor * 0.4 + spread_factor * 0.6 + + # Format price data for LLM analysis + price_summary = "\n".join( + [ + f"{venue}: Bid ${prices[venue]['bid']:.2f}, Ask ${prices[venue]['ask']:.2f}" + for venue in prices.keys() + ] + ) + + # Create detailed analysis prompt + analysis_prompt = f""" + Analyze this arbitrage opportunity for {symbol}: + + Current Prices: + {price_summary} + + Best Opportunity Found: + Buy Venue: {opp['buy_venue']} at ${opp['buy_price']:.2f} + Sell Venue: {opp['sell_venue']} at ${opp['sell_price']:.2f} + Spread: ${opp['spread']:.2f} + Expected Return: {opp['return']*100:.3f}% + Time Difference: {abs(timestamps[opp['sell_venue']] - timestamps[opp['buy_venue']])}ms + + Consider: + 1. Is this opportunity likely to be profitable after execution costs? + 2. What risks might prevent successful execution? + 3. What market conditions might have created this opportunity? + 4. How does the timing difference affect execution probability? + """ + + # Get LLM analysis + llm_analysis = self.run(analysis_prompt) + + # Create comprehensive signal + return MarketSignal( + timestamp=datetime.now(), + signal_type="arbitrage_opportunity", + source="LatencyArbitrageAgent", + data={ + "metrics": metrics, + "symbol": symbol, + "best_opportunity": metrics["best_opportunity"], + "all_prices": prices, + "llm_analysis": llm_analysis, + "timing": { + "time_difference_ms": abs( + timestamps[opp["sell_venue"]] + - timestamps[opp["buy_venue"]] + ), + "timestamps": timestamps, + }, + }, + confidence=confidence, + metadata={ + "update_latency": ( + datetime.now() - self.last_update + ).total_seconds(), + "timestamp_deltas": timestamps, + "venue_count": len(prices), + "execution_risk": 1.0 + - timing_factor, # Higher time difference = higher risk + }, + ) + + except Exception as e: + logger.error(f"Error in arbitrage analysis: {str(e)}") + return None + + +class SwarmCoordinator: + def __init__(self, api_key: str): + self.api_key = api_key + self.agents = { + "order_book": OrderBookAgent(api_key), + "tick_data": TickDataAgent(api_key), + "latency_arb": LatencyArbitrageAgent(api_key), + } + self.signal_processors = [] + self.signal_history = MarketDataBuffer(max_size=1000) + self.running = False + self.lock = threading.Lock() + self.csv_writer = SignalCSVWriter() + + def register_signal_processor(self, processor): + """Register a new signal processor function""" + with self.lock: + self.signal_processors.append(processor) + + def process_signals(self, signals: List[MarketSignal]): + """Process signals through all registered processors""" + if not signals: + return + + self.signal_history.add(signals) + + try: + for processor in self.signal_processors: + processor(signals) + except Exception as e: + logger.error(f"Error in signal processing: {e}") + + def aggregate_signals( + self, signals: List[MarketSignal] + ) -> Dict[str, Any]: + """Aggregate multiple signals into a combined market view""" + if not signals: + return {} + + self.signal_history.add(signals) + + aggregated = { + "timestamp": datetime.now(), + "symbols": set(), + "agent_signals": {}, + "combined_confidence": 0, + "market_state": {}, + } + + for signal in signals: + symbol = signal.data.get("symbol") + if symbol: + aggregated["symbols"].add(symbol) + + agent_type = signal.source + if agent_type not in aggregated["agent_signals"]: + aggregated["agent_signals"][agent_type] = [] + aggregated["agent_signals"][agent_type].append(signal) + + # Update market state based on signal type + if signal.signal_type == "order_book_analysis": + metrics = signal.data.get("metrics", {}) + aggregated["market_state"].update( + { + "order_book_imbalance": metrics.get( + "depth_imbalance" + ), + "spread": metrics.get("spread"), + "large_orders_detected": signal.data.get( + "large_order_detected" + ), + } + ) + elif signal.signal_type == "tick_analysis": + metrics = signal.data.get("metrics", {}) + aggregated["market_state"].update( + { + "price_momentum": metrics.get( + "price_momentum" + ), + "trade_intensity": metrics.get( + "trade_intensity" + ), + "kyle_lambda": metrics.get("kyle_lambda"), + } + ) + elif signal.signal_type == "arbitrage_opportunity": + opp = signal.data.get("best_opportunity") + if opp: + aggregated["market_state"].update( + { + "arbitrage_spread": opp.get("spread"), + "arbitrage_return": opp.get("return"), + } + ) + + # Calculate combined confidence as weighted average + confidences = [s.confidence for s in signals] + if confidences: + aggregated["combined_confidence"] = np.mean(confidences) + + return aggregated + + def start(self, symbols: List[str], interval: float = 1.0): + """Start the swarm monitoring system""" + if self.running: + logger.warning("Swarm is already running") + return + + self.running = True + + def agent_loop(agent, symbol): + while self.running: + try: + if isinstance(agent, OrderBookAgent): + signal = agent.analyze_order_book(symbol) + elif isinstance(agent, TickDataAgent): + signal = agent.analyze_ticks(symbol) + elif isinstance(agent, LatencyArbitrageAgent): + signal = agent.find_arbitrage(symbol) + + if signal: + agent.signal_queue.put(signal) + except Exception as e: + logger.error( + f"Error in {agent.agent_name} loop: {e}" + ) + + time.sleep(interval) + + def signal_collection_loop(): + while self.running: + try: + current_signals = [] + + # Collect signals from all agents + for agent in self.agents.values(): + while not agent.signal_queue.empty(): + signal = agent.signal_queue.get_nowait() + if signal: + current_signals.append(signal) + + if current_signals: + # Process current signals + self.process_signals(current_signals) + + # Aggregate and analyze + aggregated = self.aggregate_signals( + current_signals + ) + logger.info( + f"Aggregated market view: {aggregated}" + ) + + except Exception as e: + logger.error( + f"Error in signal collection loop: {e}" + ) + + time.sleep(interval) + + # Start agent threads + self.threads = [] + for symbol in symbols: + for agent in self.agents.values(): + thread = threading.Thread( + target=agent_loop, + args=(agent, symbol), + daemon=True, + ) + thread.start() + self.threads.append(thread) + + # Start signal collection thread + collection_thread = threading.Thread( + target=signal_collection_loop, daemon=True + ) + collection_thread.start() + self.threads.append(collection_thread) + + def stop(self): + """Stop the swarm monitoring system""" + self.running = False + for thread in self.threads: + thread.join(timeout=5.0) + logger.info("Swarm stopped") + + +def market_making_processor(signals: List[MarketSignal]): + """Enhanced signal processor with LLM analysis integration""" + for signal in signals: + if signal.confidence > 0.8: + if signal.signal_type == "arbitrage_opportunity": + opp = signal.data.get("best_opportunity") + if ( + opp and opp["return"] > 0.001 + ): # 0.1% return threshold + logger.info( + "\nSignificant arbitrage opportunity detected:" + ) + logger.info(f"Return: {opp['return']*100:.3f}%") + logger.info(f"Spread: ${opp['spread']:.2f}") + if "llm_analysis" in signal.data: + logger.info("\nLLM Analysis:") + logger.info(signal.data["llm_analysis"]) + + elif signal.signal_type == "order_book_analysis": + imbalance = signal.data["metrics"]["depth_imbalance"] + if abs(imbalance) > 0.3: + logger.info( + f"\nSignificant order book imbalance detected: {imbalance:.3f}" + ) + if "llm_analysis" in signal.data: + logger.info("\nLLM Analysis:") + logger.info(signal.data["llm_analysis"]) + + elif signal.signal_type == "tick_analysis": + momentum = signal.data["metrics"]["price_momentum"] + if abs(momentum) > 0: + logger.info( + f"\nSignificant price momentum detected: {momentum:.3f}" + ) + if "llm_analysis" in signal.data: + logger.info("\nLLM Analysis:") + logger.info(signal.data["llm_analysis"]) + + +load_dotenv() +api_key = os.getenv("OPENAI_API_KEY") + +coordinator = SwarmCoordinator(api_key) +coordinator.register_signal_processor(market_making_processor) + +symbols = ["BTC/USDT", "ETH/USDT"] + +logger.info( + "Starting market microstructure analysis with LLM integration..." +) +logger.info(f"Monitoring symbols: {symbols}") +logger.info( + f"CSV files will be written to: {os.path.abspath('market_data')}" +) + +try: + coordinator.start(symbols) + while True: + time.sleep(1) +except KeyboardInterrupt: + logger.info("Gracefully shutting down...") + coordinator.stop() diff --git a/multi_tool_usage_agent.py b/new_features_examples/multi_tool_usage_agent.py similarity index 99% rename from multi_tool_usage_agent.py rename to new_features_examples/multi_tool_usage_agent.py index 44577528..1af421e2 100644 --- a/multi_tool_usage_agent.py +++ b/new_features_examples/multi_tool_usage_agent.py @@ -1,5 +1,5 @@ import os -from typing import List, Dict, Any, Optional, Callable +from typing import List, Dict, Any, Optional, Callable, get_type_hints from dataclasses import dataclass, field import json from datetime import datetime @@ -111,6 +111,9 @@ class ExecutionContext: history: List[Dict[str, Any]] = field(default_factory=list) +hints = get_type_hints(func) + + class ToolAgent: def __init__( self, diff --git a/rearrange_test.py b/new_features_examples/rearrange_test.py similarity index 100% rename from rearrange_test.py rename to new_features_examples/rearrange_test.py diff --git a/pyproject.toml b/pyproject.toml index 51bb898f..0cc0a373 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "6.3.7" +version = "6.4.7" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -61,7 +61,7 @@ torch = ">=2.1.1,<3.0" transformers = ">= 4.39.0, <5.0.0" asyncio = ">=3.4.3,<4.0" toml = "*" -pypdf = "4.3.1" +pypdf = "5.1.0" loguru = "*" pydantic = "2.8.2" tenacity = "*" @@ -86,7 +86,7 @@ swarms = "swarms.cli.main:main" [tool.poetry.group.lint.dependencies] black = ">=23.1,<25.0" -ruff = ">=0.5.1,<0.7.4" +ruff = ">=0.5.1,<0.8.2" types-toml = "^0.10.8.1" types-pytz = ">=2023.3,<2025.0" types-chardet = "^5.0.4.6" diff --git a/requirements.txt b/requirements.txt index ca9fdcdd..e5375a0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ asyncio>=3.4.3,<4.0 toml pypdf==4.3.1 ratelimit==2.2.1 -loguru==0.7.2 +loguru pydantic==2.8.2 tenacity rich diff --git a/simple_example.py b/simple_example.py new file mode 100644 index 00000000..2fcbb8f9 --- /dev/null +++ b/simple_example.py @@ -0,0 +1,7 @@ +from swarms import Agent + +Agent( + agent_name="Stock-Analysis-Agent", + model_name="gpt-4o-mini", + max_loops=1, +).run("What are 5 hft algorithms") diff --git a/swarms/agents/auto_generate_swarm_config.py b/swarms/agents/auto_generate_swarm_config.py new file mode 100644 index 00000000..febb85e3 --- /dev/null +++ b/swarms/agents/auto_generate_swarm_config.py @@ -0,0 +1,253 @@ +import re + +from dotenv import load_dotenv +from tenacity import retry, stop_after_attempt, wait_exponential + +from swarms import Agent +from swarms.agents.create_agents_from_yaml import ( + create_agents_from_yaml, +) +from swarms.utils.formatter import formatter +from swarms.utils.litellm import LiteLLM + +load_dotenv() + + +def prepare_yaml_for_parsing(raw_yaml: str) -> str: + """ + Prepares raw YAML content by fixing spacing and formatting issues. + + Args: + raw_yaml (str): The raw YAML content extracted from Markdown. + + Returns: + str: The cleaned YAML content ready for parsing. + """ + # Fix sequence items that are improperly placed on the same line as their key + fixed_yaml = re.sub( + r"(\b\w+\b):\s*-\s*", r"\1:\n - ", raw_yaml + ) # Fix "key: - value" to "key:\n - value" + + # Ensure proper spacing after colons + fixed_yaml = re.sub( + r"(\S):(\S)", r"\1: \2", fixed_yaml + ) # Ensure space after colons + + # Remove trailing spaces before newlines + fixed_yaml = re.sub(r"\s+\n", "\n", fixed_yaml) + + # Replace non-breaking spaces (if any) with regular spaces + fixed_yaml = fixed_yaml.replace("\xa0", " ") + + return fixed_yaml.strip() + + +def parse_yaml_from_swarm_markdown(markdown_text: str) -> dict: + """ + Extracts and prepares YAML content from a Markdown-style 'Auto-Swarm-Builder' block and parses it. + + Args: + markdown_text (str): The Markdown text containing the YAML inside 'Auto-Swarm-Builder' block. + + Returns: + dict: A parsed Python dictionary of the YAML content. + """ + # Match the 'Auto-Swarm-Builder' block with YAML inside triple backticks + pattern = r"```yaml\s*\n(.*?)```" + match = re.search(pattern, markdown_text, re.DOTALL) + + if not match: + raise ValueError( + "No YAML content found in the 'Auto-Swarm-Builder' block." + ) + + raw_yaml = match.group(1).strip() + + # Preprocess and normalize the YAML content + normalized_yaml = prepare_yaml_for_parsing(raw_yaml) + + return normalized_yaml + + +AUTO_GEN_PROMPT = """ +You are a specialized agent responsible for creating YAML configuration files for multi-agent swarms. Your role is to generate well-structured YAML that defines both individual agents and swarm architectures based on user requirements. +Output only the yaml nothing else. You will be penalized for making mistakes + +GUIDELINES: +1. Each YAML file must contain an `agents` section with at least one agent configuration +2. Each agent configuration requires the following mandatory fields: + - agent_name (string) + - system_prompt (string) + +3. Optional agent fields include: + - max_loops (integer) + - autosave (boolean) + - dashboard (boolean) + - verbose (boolean) + - dynamic_temperature_enabled (boolean) + - saved_state_path (string) + - user_name (string) + - retry_attempts (integer) + - context_length (integer) + - return_step_meta (boolean) + - output_type (string) + - task (string) + +4. When a swarm is needed, include a `swarm_architecture` section with: + Mandatory fields: + - name (string) + - swarm_type (string: "ConcurrentWorkflow" or "SequentialWorkflow") [AgentRearrange, MixtureOfAgents, SpreadSheetSwarm, SequentialWorkflow, ConcurrentWorkflow] + + Optional fields: + - description (string) + - max_loops (integer) + - task (string) + +TEMPLATE STRUCTURE: +```yaml +agents: + - agent_name: "Agent-1-Name" + system_prompt: "Detailed system prompt here" + max_loops: 1 + # [additional optional fields] + + - agent_name: "Agent-2-Name" + system_prompt: "Detailed system prompt here" + # [additional optional fields] + +swarm_architecture: + name: "Swarm-Name" + description: "Swarm purpose and goals" + swarm_type: "ConcurrentWorkflow" + max_loops: 5 + task: "Main swarm task description" +``` + +VALIDATION RULES: +1. All agent names must be unique +2. System prompts must be clear and specific to the agent's role +3. Integer values must be positive +4. Boolean values must be true or false (lowercase) +5. File paths should use forward slashes +6. Tasks should be specific and aligned with the agent/swarm purpose + +When generating a YAML configuration: +1. Ask for specific requirements about the agents and swarm needed +2. Determine if a swarm architecture is necessary based on the task complexity +3. Generate appropriate system prompts for each agent based on their roles +4. Include relevant optional fields based on the use case +5. Validate the configuration against all rules before returning + +Example valid YAML configurations are provided below. Use these as references for structure and formatting: + +```yaml + + +agents: + - agent_name: "Data-Analysis-Agent" + system_prompt: "You are a specialized data analysis agent focused on processing and interpreting financial data. Provide clear, actionable insights based on the data provided." + max_loops: 3 + autosave: true + verbose: true + context_length: 100000 + output_type: "json" + task: "Analyze quarterly financial reports and identify trends" + +# Multi-Agent Swarm Example +agents: + - agent_name: "Research-Agent" + system_prompt: "You are a research agent specialized in gathering and summarizing scientific publications. Focus on peer-reviewed sources and provide comprehensive summaries." + max_loops: 2 + context_length: 150000 + output_type: "str" + + - agent_name: "Analysis-Agent" + system_prompt: "You are an analysis agent that processes research summaries and identifies key patterns and insights. Provide detailed analytical reports." + max_loops: 3 + context_length: 200000 + output_type: "json" + +swarm_architecture: + name: "Research-Analysis-Swarm" + description: "A swarm for comprehensive research analysis and insight generation" + swarm_type: "SequentialWorkflow" + max_loops: 5 + task: "Research and analyze recent developments in quantum computing" + +``` +""" + + +def generate_swarm_config( + task: str, + file_name: str = "swarm_config_output.yaml", + model_name: str = "gpt-4o", + *args, + **kwargs, +): + """ + Generates a swarm configuration based on the provided task and model name. + + This function attempts to generate a swarm configuration by running an agent with the specified task and model name. + It then parses the output into YAML format and creates agents based on the parsed YAML content. + + Args: + task (str): The task to be performed by the swarm. + file_name (str, optional): The file name for the output YAML configuration. Defaults to "swarm_config_output.yaml". + model_name (str, optional): The name of the model to use for the agent. Defaults to "gpt-4o". + *args: Additional positional arguments to be passed to the agent's run method. + **kwargs: Additional keyword arguments to be passed to the agent's run method. + + Returns: + Any: The output of the swarm configuration generation process. This can be a SwarmRouter instance or an error message. + """ + formatter.print_panel( + "Auto Generating Swarm...", "Auto Swarm Builder" + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(min=4, max=10), + ) + def attempt_generate_swarm_config(): + try: + model = LiteLLM(model_name=model_name) + + # Initialize the agent + agent = Agent( + agent_name="Auto-Swarm-Builder", + system_prompt=AUTO_GEN_PROMPT, + llm=model, + max_loops=1, + dynamic_temperature_enabled=True, + saved_state_path="swarm_builder.json", + user_name="swarms_corp", + output_type="str", + ) + + # Generate output from the agent + raw_output = agent.run(task, *args, **kwargs) + yaml_content = parse_yaml_from_swarm_markdown(raw_output) + print(yaml_content) + + # Create agents from the YAML file + output = create_agents_from_yaml( + yaml_string=yaml_content, + return_type="run_swarm", + ) + + formatter.print_panel( + "Swarm configuration generated successfully.", + "Success", + ) + + return output + + except Exception as e: + formatter.print_panel( + f"Error generating swarm configuration: {str(e)}", + "Error", + ) + raise + + return attempt_generate_swarm_config() diff --git a/swarms/agents/create_agents_from_yaml.py b/swarms/agents/create_agents_from_yaml.py index 7e6e056b..e92d1923 100644 --- a/swarms/agents/create_agents_from_yaml.py +++ b/swarms/agents/create_agents_from_yaml.py @@ -1,22 +1,168 @@ import os -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import yaml +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) +from pydantic import ( + BaseModel, + Field, + field_validator, +) from swarms.utils.loguru_logger import initialize_logger - from swarms.structs.agent import Agent from swarms.structs.swarm_router import SwarmRouter - +from swarms.utils.litellm import LiteLLM logger = initialize_logger(log_folder="create_agents_from_yaml") +class AgentConfig(BaseModel): + agent_name: str + system_prompt: str + model_name: Optional[str] = None + max_loops: int = Field(default=1, ge=1) + autosave: bool = True + dashboard: bool = False + verbose: bool = False + dynamic_temperature_enabled: bool = False + saved_state_path: Optional[str] = None + user_name: str = "default_user" + retry_attempts: int = Field(default=3, ge=1) + context_length: int = Field(default=100000, ge=1000) + return_step_meta: bool = False + output_type: str = "str" + auto_generate_prompt: bool = False + artifacts_on: bool = False + artifacts_file_extension: str = ".md" + artifacts_output_path: str = "" + + @field_validator("system_prompt") + @classmethod + def validate_system_prompt(cls, v): + if not v or not isinstance(v, str) or len(v.strip()) == 0: + raise ValueError( + "System prompt must be a non-empty string" + ) + return v + + +class SwarmConfig(BaseModel): + name: str + description: str + max_loops: int = Field(default=1, ge=1) + swarm_type: str + task: Optional[str] = None + flow: Optional[Dict] = None + autosave: bool = True + return_json: bool = False + rules: str = "" + + @field_validator("swarm_type") + @classmethod + def validate_swarm_type(cls, v): + valid_types = { + "SequentialWorkflow", + "ConcurrentWorkflow", + "AgentRearrange", + "MixtureOfAgents", + "auto", + } + if v not in valid_types: + raise ValueError( + f"Swarm type must be one of: {valid_types}" + ) + return v + + +class YAMLConfig(BaseModel): + agents: List[AgentConfig] = Field(..., min_length=1) + swarm_architecture: Optional[SwarmConfig] = None + + model_config = { + "extra": "forbid" # Prevent additional fields not in the model + } + + +def load_yaml_safely( + yaml_file: str = None, yaml_string: str = None +) -> Dict: + """Safely load and validate YAML configuration using Pydantic.""" + try: + if yaml_string: + config_dict = yaml.safe_load(yaml_string) + elif yaml_file: + if not os.path.exists(yaml_file): + raise FileNotFoundError( + f"YAML file {yaml_file} not found." + ) + with open(yaml_file, "r") as file: + config_dict = yaml.safe_load(file) + else: + raise ValueError( + "Either yaml_file or yaml_string must be provided" + ) + + # Validate using Pydantic + YAMLConfig(**config_dict) + return config_dict + except yaml.YAMLError as e: + raise ValueError(f"Error parsing YAML: {str(e)}") + except Exception as e: + raise ValueError(f"Error validating configuration: {str(e)}") + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, TimeoutError)), + before_sleep=lambda retry_state: logger.info( + f"Retrying after error: {retry_state.outcome.exception()}" + ), +) +def create_agent_with_retry( + agent_config: Dict, model: LiteLLM +) -> Agent: + """Create an agent with retry logic for handling transient failures.""" + try: + validated_config = AgentConfig(**agent_config) + agent = Agent( + agent_name=validated_config.agent_name, + system_prompt=validated_config.system_prompt, + llm=model, + max_loops=validated_config.max_loops, + autosave=validated_config.autosave, + dashboard=validated_config.dashboard, + verbose=validated_config.verbose, + dynamic_temperature_enabled=validated_config.dynamic_temperature_enabled, + saved_state_path=validated_config.saved_state_path, + user_name=validated_config.user_name, + retry_attempts=validated_config.retry_attempts, + context_length=validated_config.context_length, + return_step_meta=validated_config.return_step_meta, + output_type=validated_config.output_type, + auto_generate_prompt=validated_config.auto_generate_prompt, + artifacts_on=validated_config.artifacts_on, + artifacts_file_extension=validated_config.artifacts_file_extension, + artifacts_output_path=validated_config.artifacts_output_path, + ) + return agent + except Exception as e: + logger.error( + f"Error creating agent {agent_config.get('agent_name', 'unknown')}: {str(e)}" + ) + raise + + def create_agents_from_yaml( model: Callable = None, yaml_file: str = "agents.yaml", + yaml_string: str = None, return_type: str = "auto", - *args, - **kwargs, ) -> Union[ SwarmRouter, Agent, @@ -25,171 +171,99 @@ def create_agents_from_yaml( List[Dict[str, Any]], ]: """ - Create agents and/or SwarmRouter based on configurations defined in a YAML file. - - This function dynamically creates agents and a SwarmRouter (if specified) based on the - configuration in the YAML file. It adapts its behavior based on the presence of a - swarm architecture and the number of agents defined. - - Args: - model (Callable): The language model to be used by the agents. - yaml_file (str): Path to the YAML file containing agent and swarm configurations. - return_type (str): Determines the return value. Options are: - "auto" (default): Automatically determine the most appropriate return type. - "swarm": Return SwarmRouter if present, otherwise a single agent or list of agents. - "agents": Return a list of agents (or a single agent if only one is defined). - "both": Return both SwarmRouter (or single agent) and list of agents. - "tasks": Return task results if any tasks were executed. - "run_swarm": Run the swarm and return its output. - *args: Additional positional arguments for agent or SwarmRouter customization. - **kwargs: Additional keyword arguments for agent or SwarmRouter customization. - - Returns: - Union[SwarmRouter, Agent, List[Agent], Tuple[Union[SwarmRouter, Agent], List[Agent]], List[Dict[str, Any]]]: - The return type depends on the 'return_type' argument and the configuration in the YAML file. - - Raises: - FileNotFoundError: If the specified YAML file is not found. - ValueError: If the YAML configuration is invalid or if an invalid return_type is specified. + Create agents and/or SwarmRouter based on configurations defined in a YAML file or string. """ - try: - logger.info( - f"Checking if the YAML file {yaml_file} exists..." - ) - - if not os.path.exists(yaml_file): - logger.error(f"YAML file {yaml_file} not found.") - raise FileNotFoundError( - f"YAML file {yaml_file} not found." - ) - - logger.info(f"Loading YAML file {yaml_file}") - with open(yaml_file, "r") as file: - config = yaml.safe_load(file) - - if "agents" not in config: - logger.error( - "The YAML configuration does not contain 'agents'." - ) - raise ValueError( - "The YAML configuration does not contain 'agents'." - ) + agents = [] + task_results = [] + swarm_router = None - agents = [] - task_results = [] + try: + # Load and validate configuration + config = load_yaml_safely(yaml_file, yaml_string) - # Create agents + # Create agents with retry logic for agent_config in config["agents"]: logger.info( f"Creating agent: {agent_config['agent_name']}" ) - if "system_prompt" not in agent_config: - logger.error( - f"System prompt is missing for agent: {agent_config['agent_name']}" - ) - raise ValueError( - f"System prompt is missing for agent: {agent_config['agent_name']}" + if "model_name" in agent_config: + model_instance = LiteLLM( + model_name=agent_config["model_name"] ) + else: + model_name = "gpt-4o" + model_instance = LiteLLM(model_name=model_name) - agent = Agent( - agent_name=agent_config["agent_name"], - system_prompt=agent_config["system_prompt"], - llm=model, - max_loops=agent_config.get("max_loops", 1), - autosave=agent_config.get("autosave", True), - dashboard=agent_config.get("dashboard", False), - verbose=agent_config.get("verbose", False), - dynamic_temperature_enabled=agent_config.get( - "dynamic_temperature_enabled", False - ), - saved_state_path=agent_config.get("saved_state_path"), - user_name=agent_config.get( - "user_name", "default_user" - ), - retry_attempts=agent_config.get("retry_attempts", 1), - context_length=agent_config.get( - "context_length", 100000 - ), - return_step_meta=agent_config.get( - "return_step_meta", False - ), - output_type=agent_config.get("output_type", "str"), - auto_generate_prompt=agent_config.get( - "auto_generate_prompt", "False" - ), - artifacts_on=agent_config.get( - "artifacts_on", "False" - ), - artifacts_file_extension=agent_config.get( - "artifacts_file_extension", ".md" - ), - artifacts_output_path=agent_config.get( - "artifacts_output_path", "" - ), - *args, - **kwargs, + agent = create_agent_with_retry( + agent_config, model_instance ) - logger.info( f"Agent {agent_config['agent_name']} created successfully." ) agents.append(agent) - # Create SwarmRouter if swarm_architecture is present - swarm_router = None + # Create SwarmRouter if specified if "swarm_architecture" in config: - swarm_config = config["swarm_architecture"] - swarm_router = SwarmRouter( - name=swarm_config["name"], - description=swarm_config["description"], - max_loops=swarm_config["max_loops"], - agents=agents, - swarm_type=swarm_config["swarm_type"], - task=swarm_config.get("task"), - flow=swarm_config.get("flow"), - autosave=swarm_config.get("autosave"), - return_json=swarm_config.get("return_json"), - rules=swarm_config.get("rules", "") * args, - **kwargs, - ) - logger.info( - f"SwarmRouter '{swarm_config['name']}' created successfully." + try: + swarm_config = SwarmConfig( + **config["swarm_architecture"] + ) + swarm_router = SwarmRouter( + name=swarm_config.name, + description=swarm_config.description, + max_loops=swarm_config.max_loops, + agents=agents, + swarm_type=swarm_config.swarm_type, + task=swarm_config.task, + flow=swarm_config.flow, + autosave=swarm_config.autosave, + return_json=swarm_config.return_json, + rules=swarm_config.rules, + ) + logger.info( + f"SwarmRouter '{swarm_config.name}' created successfully." + ) + except Exception as e: + logger.error(f"Error creating SwarmRouter: {str(e)}") + raise ValueError( + f"Failed to create SwarmRouter: {str(e)}" + ) + + # Handle return types with improved error checking + valid_return_types = { + "auto", + "swarm", + "agents", + "both", + "tasks", + "run_swarm", + } + if return_type not in valid_return_types: + raise ValueError( + f"Invalid return_type. Must be one of: {valid_return_types}" ) - # Define function to run SwarmRouter - def run_swarm_router( - task: str = ( - swarm_config.get("task") - if "swarm_architecture" in config - else None - ), - ): - if swarm_router: - try: - output = swarm_router.run(task) - print(output) - logger.info( - f"Output for SwarmRouter '{swarm_config['name']}': {output}" - ) - return output - except Exception as e: - logger.error( - f"Error running task for SwarmRouter '{swarm_config['name']}': {e}" - ) - raise e - else: - logger.error("SwarmRouter not created.") - raise ValueError("SwarmRouter not created.") + if return_type == "run_swarm" or "swarm": + if not swarm_router: + raise ValueError( + "Cannot run swarm: SwarmRouter not created." + ) + try: + return swarm_router.run( + config["swarm_architecture"]["task"] + ) + except Exception as e: + logger.error(f"Error running SwarmRouter: {str(e)}") + raise - # Handle return types + # Return appropriate type based on configuration if return_type == "auto": - if swarm_router: - return swarm_router - elif len(agents) == 1: - return agents[0] - else: - return agents + return ( + swarm_router + if swarm_router + else (agents[0] if len(agents) == 1 else agents) + ) elif return_type == "swarm": return ( swarm_router @@ -205,24 +279,10 @@ def create_agents_from_yaml( else agents[0] if len(agents) == 1 else agents ), agents elif return_type == "tasks": - if not task_results: - logger.warning( - "No tasks were executed. Returning empty list." - ) return task_results - elif return_type == "run_swarm": - if swarm_router: - return run_swarm_router() - else: - logger.error( - "Cannot run swarm: SwarmRouter not created." - ) - raise ValueError( - "Cannot run swarm: SwarmRouter not created." - ) - else: - logger.error(f"Invalid return_type: {return_type}") - raise ValueError(f"Invalid return_type: {return_type}") + except Exception as e: - logger.error(f"An error occurred: {e}") - raise e + logger.error( + f"Critical error in create_agents_from_yaml: {str(e)}" + ) + raise diff --git a/swarms/agents/openai_assistant.py b/swarms/agents/openai_assistant.py new file mode 100644 index 00000000..acedf362 --- /dev/null +++ b/swarms/agents/openai_assistant.py @@ -0,0 +1,264 @@ +from typing import Optional, List, Dict, Any, Callable +import time +from openai import OpenAI +from swarms.structs.agent import Agent +import json + +class OpenAIAssistant(Agent): + """ + OpenAI Assistant wrapper for the swarms framework. + Integrates OpenAI's Assistants API with the swarms architecture. + + Example: + >>> assistant = OpenAIAssistant( + ... name="Math Tutor", + ... instructions="You are a personal math tutor.", + ... model="gpt-4o", + ... tools=[{"type": "code_interpreter"}] + ... ) + >>> response = assistant.run("Solve 3x + 11 = 14") + """ + + def __init__( + self, + name: str, + instructions: Optional[str] = None, + model: str = "gpt-4o", + tools: Optional[List[Dict[str, Any]]] = None, + file_ids: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + *args, + **kwargs + ): + """Initialize the OpenAI Assistant. + + Args: + name: Name of the assistant + instructions: System instructions for the assistant + model: Model to use (default: gpt-4-turbo-preview) + tools: List of tools to enable (code_interpreter, retrieval) + file_ids: List of file IDs to attach + metadata: Additional metadata + functions: List of custom functions to make available + """ + super().__init__(*args, **kwargs) + + # Initialize tools list with any provided functions + self.tools = tools or [] + if functions: + for func in functions: + self.tools.append({ + "type": "function", + "function": func + }) + + # Create the OpenAI Assistant + self.client = OpenAI() + self.assistant = self.client.beta.assistants.create( + name=name, + instructions=instructions, + model=model, + tools=self.tools, + file_ids=file_ids or [], + metadata=metadata or {} + ) + + # Store available functions + self.available_functions: Dict[str, Callable] = {} + + def add_function(self, func: Callable, description: str, parameters: Dict[str, Any]) -> None: + """Add a function that the assistant can call. + + Args: + func: The function to make available to the assistant + description: Description of what the function does + parameters: JSON schema describing the function parameters + """ + func_dict = { + "name": func.__name__, + "description": description, + "parameters": parameters + } + + # Add to tools list + self.tools.append({ + "type": "function", + "function": func_dict + }) + + # Store function reference + self.available_functions[func.__name__] = func + + # Update assistant with new tools + self.assistant = self.client.beta.assistants.update( + assistant_id=self.assistant.id, + tools=self.tools + ) + + def _handle_tool_calls(self, run, thread_id: str) -> None: + """Handle any required tool calls during a run. + + This method processes any tool calls required by the assistant during execution. + It extracts function calls, executes them with provided arguments, and submits + the results back to the assistant. + + Args: + run: The current run object from the OpenAI API + thread_id: ID of the current conversation thread + + Returns: + Updated run object after processing tool calls + + Raises: + Exception: If there are errors executing the tool calls + """ + while run.status == "requires_action": + tool_calls = run.required_action.submit_tool_outputs.tool_calls + tool_outputs = [] + + for tool_call in tool_calls: + if tool_call.type == "function": + # Get function details + function_name = tool_call.function.name + function_args = json.loads(tool_call.function.arguments) + + # Call function if available + if function_name in self.available_functions: + function_response = self.available_functions[function_name](**function_args) + tool_outputs.append({ + "tool_call_id": tool_call.id, + "output": str(function_response) + }) + + # Submit outputs back to the run + run = self.client.beta.threads.runs.submit_tool_outputs( + thread_id=thread_id, + run_id=run.id, + tool_outputs=tool_outputs + ) + + # Wait for processing + run = self._wait_for_run(run) + + return run + + def _wait_for_run(self, run) -> Any: + """Wait for a run to complete and handle any required actions. + + This method polls the OpenAI API to check the status of a run until it completes + or fails. It handles intermediate states like required actions and implements + exponential backoff. + + Args: + run: The run object to monitor + + Returns: + The completed run object + + Raises: + Exception: If the run fails or expires + """ + while True: + run = self.client.beta.threads.runs.retrieve( + thread_id=run.thread_id, + run_id=run.id + ) + + if run.status == "completed": + break + elif run.status == "requires_action": + run = self._handle_tool_calls(run, run.thread_id) + if run.status == "completed": + break + elif run.status in ["failed", "expired"]: + raise Exception(f"Run failed with status: {run.status}") + + time.sleep(3) # Wait 3 seconds before checking again + + return run + + def _ensure_thread(self): + """Ensure a thread exists for the conversation. + + This method checks if there is an active thread for the current conversation. + If no thread exists, it creates a new one. This maintains conversation context + across multiple interactions. + + Side Effects: + Sets self.thread if it doesn't exist + """ + if not self.thread: + self.thread = self.client.beta.threads.create() + + def add_message(self, content: str, file_ids: Optional[List[str]] = None) -> None: + """Add a message to the thread. + + This method adds a new user message to the conversation thread. It ensures + a thread exists before adding the message and handles file attachments. + + Args: + content: The text content of the message to add + file_ids: Optional list of file IDs to attach to the message. These must be + files that have been previously uploaded to OpenAI. + + Side Effects: + Creates a new thread if none exists + Adds the message to the thread in OpenAI's system + """ + self._ensure_thread() + self.client.beta.threads.messages.create( + thread_id=self.thread.id, + role="user", + content=content, + file_ids=file_ids or [] + ) + + def _get_response(self) -> str: + """Get the latest assistant response from the thread.""" + messages = self.client.beta.threads.messages.list( + thread_id=self.thread.id, + order="desc", + limit=1 + ) + + if not messages.data: + return "" + + message = messages.data[0] + if message.role == "assistant": + return message.content[0].text.value + return "" + + def run(self, task: str, *args, **kwargs) -> str: + """Run a task using the OpenAI Assistant. + + Args: + task: The task or prompt to send to the assistant + + Returns: + The assistant's response as a string + """ + self._ensure_thread() + + # Add the user message + self.add_message(task) + + # Create and run the assistant + run = self.client.beta.threads.runs.create( + thread_id=self.thread.id, + assistant_id=self.assistant.id, + instructions=self.instructions + ) + + # Wait for completion + run = self._wait_for_run(run) + + # Only get and return the response if run completed successfully + if run.status == "completed": + return self._get_response() + return "" + + def call(self, task: str, *args, **kwargs) -> str: + """Alias for run() to maintain compatibility with different agent interfaces.""" + return self.run(task, *args, **kwargs) \ No newline at end of file diff --git a/swarms/artifacts/__init__.py b/swarms/artifacts/__init__.py index 448d6101..a1a027b4 100644 --- a/swarms/artifacts/__init__.py +++ b/swarms/artifacts/__init__.py @@ -1,9 +1,5 @@ -from swarms.artifacts.base_artifact import BaseArtifact -from swarms.artifacts.text_artifact import TextArtifact from swarms.artifacts.main_artifact import Artifact __all__ = [ - "BaseArtifact", - "TextArtifact", "Artifact", ] diff --git a/swarms/artifacts/base_artifact.py b/swarms/artifacts/base_artifact.py deleted file mode 100644 index aad07a7b..00000000 --- a/swarms/artifacts/base_artifact.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - - -@dataclass -class BaseArtifact(ABC): - """ - Base class for artifacts. - """ - - id: str - name: str - value: Any - - def __post_init__(self): - if self.id is None: - self.id = uuid.uuid4().hex - if self.name is None: - self.name = self.id - - @classmethod - def value_to_bytes(cls, value: Any) -> bytes: - """ - Convert the value to bytes. - """ - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @classmethod - def value_to_dict(cls, value: Any) -> dict: - """ - Convert the value to a dictionary. - """ - if isinstance(value, dict): - dict_value = value - else: - dict_value = json.loads(value) - - return {k: v for k, v in dict_value.items()} - - def to_text(self) -> str: - """ - Convert the value to text. - """ - return str(self.value) - - def __str__(self) -> str: - """ - Return a string representation of the artifact. - """ - return self.to_text() - - def __bool__(self) -> bool: - """ - Return the boolean value of the artifact. - """ - return bool(self.value) - - def __len__(self) -> int: - """ - Return the length of the artifact. - """ - return len(self.value) - - @abstractmethod - def __add__(self, other: BaseArtifact) -> BaseArtifact: - """ - Add two artifacts together. - """ - ... diff --git a/swarms/artifacts/text_artifact.py b/swarms/artifacts/text_artifact.py deleted file mode 100644 index 13ca4dfd..00000000 --- a/swarms/artifacts/text_artifact.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Callable -from swarms.artifacts.base_artifact import BaseArtifact - - -@dataclass -class TextArtifact(BaseArtifact): - """ - Represents a text artifact. - - Attributes: - value (str): The text value of the artifact. - encoding (str, optional): The encoding of the text (default is "utf-8"). - encoding_error_handler (str, optional): The error handler for encoding errors (default is "strict"). - _embedding (list[float]): The embedding of the text artifact (default is an empty list). - - Properties: - embedding (Optional[list[float]]): The embedding of the text artifact. - - Methods: - __add__(self, other: BaseArtifact) -> TextArtifact: Concatenates the text value of the artifact with another artifact. - __bool__(self) -> bool: Checks if the text value of the artifact is non-empty. - generate_embedding(self, driver: BaseEmbeddingModel) -> Optional[list[float]]: Generates the embedding of the text artifact using a given embedding model. - token_count(self, tokenizer: BaseTokenizer) -> int: Counts the number of tokens in the text artifact using a given tokenizer. - to_bytes(self) -> bytes: Converts the text value of the artifact to bytes using the specified encoding and error handler. - """ - - value: str - encoding: str = "utf-8" - encoding_error_handler: str = "strict" - tokenizer: Callable = None - _embedding: list[float] = field(default_factory=list) - - @property - def embedding(self) -> list[float] | None: - return None if len(self._embedding) == 0 else self._embedding - - def __add__(self, other: BaseArtifact) -> TextArtifact: - return TextArtifact(self.value + other.value) - - def __bool__(self) -> bool: - return bool(self.value.strip()) - - def generate_embedding(self, model) -> list[float] | None: - self._embedding.clear() - self._embedding.extend(model.embed_string(str(self.value))) - - return self.embedding - - def token_count(self) -> int: - return self.tokenizer.count_tokens(str(self.value)) - - def to_bytes(self) -> bytes: - return self.value.encode( - encoding=self.encoding, errors=self.encoding_error_handler - ) diff --git a/swarms/cli/main.py b/swarms/cli/main.py index 738deec6..5abe8b58 100644 --- a/swarms/cli/main.py +++ b/swarms/cli/main.py @@ -1,244 +1,348 @@ import argparse import os +import subprocess import time +import webbrowser from rich.console import Console +from rich.panel import Panel +from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.table import Table from rich.text import Text -from swarms.cli.onboarding_process import OnboardingProcess + +from swarms.agents.auto_generate_swarm_config import ( + generate_swarm_config, +) from swarms.agents.create_agents_from_yaml import ( create_agents_from_yaml, ) -import subprocess +from swarms.cli.onboarding_process import OnboardingProcess +from swarms.utils.formatter import formatter +# Initialize console with custom styling console = Console() -ASCII_ART = """ - _________ - / _____/_ _ _______ _______ _____ ______ - \_____ \\ \/ \/ /\__ \\_ __ \/ \ / ___/ - / \\ / / __ \| | \/ Y Y \\___ \ -/_______ / \/\_/ (____ /__| |__|_| /____ > - \/ \/ \/ \/ +class SwarmCLIError(Exception): + """Custom exception for Swarm CLI errors""" + + pass + +# Color scheme +COLORS = { + "primary": "red", + "secondary": "#FF6B6B", + "accent": "#4A90E2", + "success": "#2ECC71", + "warning": "#F1C40F", + "error": "#E74C3C", + "text": "#FFFFFF", +} + +ASCII_ART = """ + ▄████████ ▄█ █▄ ▄████████ ▄████████ ▄▄▄▄███▄▄▄▄ ▄████████ + ███ ███ ███ ███ ███ ███ ███ ███ ▄██▀▀▀███▀▀▀██▄ ███ ███ + ███ █▀ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ █▀ + ███ ███ ███ ███ ███ ▄███▄▄▄▄██▀ ███ ███ ███ ███ +▀███████████ ███ ███ ▀███████████ ▀▀███▀▀▀▀▀ ███ ███ ███ ▀███████████ + ███ ███ ███ ███ ███ ▀███████████ ███ ███ ███ ███ + ▄█ ███ ███ ▄█▄ ███ ███ ███ ███ ███ ███ ███ ███ ▄█ ███ + ▄████████▀ ▀███▀███▀ ███ █▀ ███ ███ ▀█ ███ █▀ ▄████████▀ + ███ ███ """ -# Function to display the ASCII art in red +def create_spinner(text: str) -> Progress: + """Create a custom spinner with the given text.""" + return Progress( + SpinnerColumn(style=COLORS["primary"]), + TextColumn("[{task.description}]", style=COLORS["text"]), + console=console, + ) + + def show_ascii_art(): - text = Text(ASCII_ART, style="bold cyan") - console.print(text) + """Display the ASCII art with a glowing effect.""" + panel = Panel( + Text(ASCII_ART, style=f"bold {COLORS['primary']}"), + border_style=COLORS["secondary"], + title="[bold]Welcome to Swarms[/bold]", + subtitle="[dim]Power to the Swarms[/dim]", + ) + console.print(panel) -# Help command -def show_help(): - console.print( - """ - [bold cyan]Swarms CLI - Help[/bold cyan] - - [bold magenta]Commands:[/bold magenta] - [bold white]onboarding[/bold white] : Starts the onboarding process - [bold white]help[/bold white] : Shows this help message - [bold white]get-api-key[/bold white] : Retrieves your API key from the platform - [bold white]check-login[/bold white] : Checks if you're logged in and starts the cache - [bold white]read-docs[/bold white] : Redirects you to swarms cloud documentation! - [bold white]run-agents[/bold white] : Run your Agents from your specified yaml file. Specify the yaml file with path the `--yaml-file` arg. Example: `--yaml-file agents.yaml` - [bold white]generate-prompt[/bold white] : Generate a prompt through automated prompt engineering. Requires an OPENAI Key in your `.env` Example: --prompt "Generate a prompt for an agent to analyze legal docs" - [bold white]auto-upgrade[/bold white] : Automatically upgrades Swarms to the latest version - [bold white]book-call[/bold white] : Book a strategy session with our team to discuss your use case and get personalized guidance - - For more details, visit: https://docs.swarms.world - """ +def create_command_table() -> Table: + """Create a beautifully formatted table of commands.""" + table = Table( + show_header=True, + header_style=f"bold {COLORS['primary']}", + border_style=COLORS["secondary"], + title="Available Commands", + padding=(0, 2), ) - # [bold white]add-agent[/bold white] : Add an agent to the marketplace under your name. Must have a Dockerfile + your agent.yaml to publish. Learn more Here: https://docs.swarms.world/en/latest/swarms_cloud/vision/ + table.add_column("Command", style="bold white") + table.add_column("Description", style="dim white") + commands = [ + ("onboarding", "Start the interactive onboarding process"), + ("help", "Display this help message"), + ("get-api-key", "Retrieve your API key from the platform"), + ("check-login", "Verify login status and initialize cache"), + ("run-agents", "Execute agents from your YAML configuration"), + ("auto-upgrade", "Update Swarms to the latest version"), + ("book-call", "Schedule a strategy session with our team"), + ("autoswarm", "Generate and execute an autonomous swarm"), + ] -# Fetch API key from platform -def get_api_key(): + for cmd, desc in commands: + table.add_row(cmd, desc) + + return table + + +def show_help(): + """Display a beautifully formatted help message.""" console.print( - "[bold yellow]Opening the API key retrieval page...[/bold yellow]" + "\n[bold]Swarms CLI - Command Reference[/bold]\n", + style=COLORS["primary"], ) - # Simulating API key retrieval process by opening the website - import webbrowser - - webbrowser.open("https://swarms.world/platform/api-keys") - time.sleep(2) + console.print(create_command_table()) console.print( - "[bold green]Your API key is available on the dashboard.[/bold green]" + "\n[dim]For detailed documentation, visit: https://docs.swarms.world[/dim]" ) -# Redirect to docs -def redirect_to_docs(): - console.print( - "[bold yellow]Opening the Docs page...[/bold yellow]" +def show_error(message: str, help_text: str = None): + """Display error message in a formatted panel""" + error_panel = Panel( + f"[bold red]{message}[/bold red]", + title="Error", + border_style="red", ) - # Simulating API key retrieval process by opening the website - import webbrowser + console.print(error_panel) - webbrowser.open("https://docs.swarms.world") - time.sleep(2) + if help_text: + console.print(f"\n[yellow]ℹ️ {help_text}[/yellow]") -# Redirect to docs -def redirect_to_call(): +def execute_with_spinner(action: callable, text: str) -> None: + """Execute an action with a spinner animation.""" + with create_spinner(text) as progress: + task = progress.add_task(text, total=None) + result = action() + progress.remove_task(task) + return result + + +def get_api_key(): + """Retrieve API key with visual feedback.""" + with create_spinner("Opening API key portal...") as progress: + task = progress.add_task("Opening browser...") + webbrowser.open("https://swarms.world/platform/api-keys") + time.sleep(1) + progress.remove_task(task) console.print( - "[bold yellow]Opening the Call page...[/bold yellow]" + f"\n[{COLORS['success']}]✓ API key page opened in your browser[/{COLORS['success']}]" ) - # Simulating API key retrieval process by opening the website - import webbrowser - webbrowser.open("https://cal.com/swarms/swarms-strategy-session") - time.sleep(2) - -# Check and start cache (login system simulation) def check_login(): + """Verify login status with enhanced visual feedback.""" cache_file = "cache.txt" if os.path.exists(cache_file): with open(cache_file, "r") as f: - cache_content = f.read() - if cache_content == "logged_in": + if f.read() == "logged_in": + console.print( + f"[{COLORS['success']}]✓ Authentication verified[/{COLORS['success']}]" + ) + return True + + with create_spinner("Authenticating...") as progress: + task = progress.add_task("Initializing session...") + time.sleep(1) + with open(cache_file, "w") as f: + f.write("logged_in") + progress.remove_task(task) + + console.print( + f"[{COLORS['success']}]✓ Login successful![/{COLORS['success']}]" + ) + return True + + +def run_autoswarm(task: str, model: str): + """Run autoswarm with enhanced error handling""" + try: + console.print( + "[yellow]Initializing autoswarm configuration...[/yellow]" + ) + + # Set LiteLLM verbose mode for debugging + import litellm + + litellm.set_verbose = True + + # Validate inputs + if not task or task.strip() == "": + raise SwarmCLIError("Task cannot be empty") + + if not model or model.strip() == "": + raise SwarmCLIError("Model name cannot be empty") + + # Attempt to generate swarm configuration + console.print( + f"[yellow]Generating swarm for task: {task}[/yellow]" + ) + result = generate_swarm_config(task=task, model=model) + + if result: console.print( - "[bold green]You are already logged in.[/bold green]" + "[green]✓ Swarm configuration generated successfully![/green]" ) else: - console.print( - "[bold red]You are not logged in.[/bold red]" + raise SwarmCLIError( + "Failed to generate swarm configuration" + ) + + except Exception as e: + if "No YAML content found" in str(e): + show_error( + "Failed to generate YAML configuration", + "This might be due to an API key issue or invalid model configuration.\n" + + "1. Check if your OpenAI API key is set correctly\n" + + "2. Verify the model name is valid\n" + + "3. Try running with --model gpt-4", + ) + else: + show_error( + f"Error during autoswarm execution: {str(e)}", + "For debugging, try:\n" + + "1. Check your API keys are set correctly\n" + + "2. Verify your network connection\n" + + "3. Try a different model", ) - else: - console.print("[bold yellow]Logging in...[/bold yellow]") - time.sleep(2) - with open(cache_file, "w") as f: - f.write("logged_in") - console.print("[bold green]Login successful![/bold green]") def check_and_upgrade_version(): - console.print( - "[bold yellow]Checking for Swarms updates...[/bold yellow]" - ) - try: - # Check for updates using pip + """Check for updates with visual progress.""" + + def check_update(): result = subprocess.run( ["pip", "list", "--outdated", "--format=freeze"], capture_output=True, text=True, ) - outdated_packages = result.stdout.splitlines() + return result.stdout.splitlines() - # Check if Swarms is outdated - for package in outdated_packages: - if package.startswith("swarms=="): - console.print( - "[bold magenta]New version available! Upgrading...[/bold magenta]" + outdated = execute_with_spinner( + check_update, "Checking for updates..." + ) + + for package in outdated: + if package.startswith("swarms=="): + console.print( + f"[{COLORS['warning']}]↑ Update available![/{COLORS['warning']}]" + ) + with create_spinner("Upgrading Swarms...") as progress: + task = progress.add_task( + "Installing latest version..." ) subprocess.run( ["pip", "install", "--upgrade", "swarms"], check=True, ) - console.print( - "[bold green]Swarms upgraded successfully![/bold green]" - ) - return + progress.remove_task(task) + console.print( + f"[{COLORS['success']}]✓ Swarms upgraded successfully![/{COLORS['success']}]" + ) + return - console.print( - "[bold green]Swarms is up-to-date.[/bold green]" - ) - except Exception as e: - console.print( - f"[bold red]Error checking for updates: {e}[/bold red]" - ) + console.print( + f"[{COLORS['success']}]✓ Swarms is up to date![/{COLORS['success']}]" + ) -# Main CLI handler def main(): - parser = argparse.ArgumentParser(description="Swarms Cloud CLI") - - # Adding arguments for different commands - parser.add_argument( - "command", - choices=[ - "onboarding", - "help", - "get-api-key", - "check-login", - "run-agents", - "generate-prompt", # Added new command for generating prompts - "auto-upgrade", # Added new command for auto-upgrade, - "book-call", - ], - help="Command to run", - ) - parser.add_argument( - "--yaml-file", - type=str, - default="agents.yaml", - help="Specify the YAML file for running agents", - ) - parser.add_argument( - "--prompt", - type=str, - help="Specify the task for generating a prompt", - ) - parser.add_argument( - "--num-loops", - type=int, - default=1, - help="Specify the number of loops for generating a prompt", - ) - parser.add_argument( - "--autosave", - action="store_true", - help="Enable autosave for the prompt generator", - ) - parser.add_argument( - "--save-to-yaml", - action="store_true", - help="Save the generated prompt to a YAML file", - ) + try: - args = parser.parse_args() - - show_ascii_art() - - # Determine which command to run - if args.command == "onboarding": - OnboardingProcess().run() - elif args.command == "help": - show_help() - elif args.command == "get-api-key": - get_api_key() - elif args.command == "check-login": - check_login() - elif args.command == "run-agents": - create_agents_from_yaml( - yaml_file=args.yaml_file, return_type="tasks" + show_ascii_art() + + parser = argparse.ArgumentParser( + description="Swarms Cloud CLI" ) - # elif args.command == "generate-prompt": - # if ( - # args.prompt - # ): # Corrected from args.prompt_task to args.prompt - # generate_prompt( - # num_loops=args.num_loops, - # autosave=args.autosave, - # save_to_yaml=args.save_to_yaml, - # prompt=args.prompt, # Corrected from args.prompt_task to args.prompt - # ) - # else: - # console.print( - # "[bold red]Please specify a task for generating a prompt using '--prompt'.[/bold red]" - # ) - elif args.command == "auto-upgrade": - check_and_upgrade_version() - elif args.command == "book-call": - redirect_to_call() - else: - console.print( - "[bold red]Unknown command! Type 'help' for usage.[/bold red]" + parser.add_argument( + "command", + choices=[ + "onboarding", + "help", + "get-api-key", + "check-login", + "run-agents", + "auto-upgrade", + "book-call", + "autoswarm", + ], + help="Command to execute", + ) + parser.add_argument( + "--yaml-file", + type=str, + default="agents.yaml", + help="YAML configuration file path", + ) + parser.add_argument( + "--task", type=str, help="Task for autoswarm" + ) + parser.add_argument( + "--model", + type=str, + default="gpt-4", + help="Model for autoswarm", + ) + + args = parser.parse_args() + + try: + if args.command == "onboarding": + OnboardingProcess().run() + elif args.command == "help": + show_help() + elif args.command == "get-api-key": + get_api_key() + elif args.command == "check-login": + check_login() + elif args.command == "run-agents": + create_agents_from_yaml( + yaml_file=args.yaml_file, return_type="tasks" + ) + elif args.command == "auto-upgrade": + check_and_upgrade_version() + elif args.command == "book-call": + webbrowser.open( + "https://cal.com/swarms/swarms-strategy-session" + ) + elif args.command == "autoswarm": + if not args.task: + show_error( + "Missing required argument: --task", + "Example usage: python cli.py autoswarm --task 'analyze this data' --model gpt-4", + ) + exit(1) + run_autoswarm(args.task, args.model) + except Exception as e: + console.print( + f"[{COLORS['error']}]Error: {str(e)}[/{COLORS['error']}]" + ) + return + except Exception as error: + formatter.print_panel( + f"Error detected: {error} check your args" ) + raise error if __name__ == "__main__": diff --git a/swarms/cli/onboarding_process.py b/swarms/cli/onboarding_process.py index 71c063c2..edac1168 100644 --- a/swarms/cli/onboarding_process.py +++ b/swarms/cli/onboarding_process.py @@ -87,19 +87,6 @@ class OnboardingProcess: try: combined_data = {**self.user_data, **self.system_data} log_agent_data(combined_data) - # threading.Thread(target=log_agent_data(combined_data)).start() - # with open(self.auto_save_path, "w") as f: - # json.dump(combined_data, f, indent=4) - # # logger.info( - # # "User and system data successfully saved to {}", - # # self.auto_save_path, - # # ) - # with open(self.cache_save_path, "w") as f: - # json.dump(combined_data, f, indent=4) - # logger.info( - # "User and system data successfully cached in {}", - # self.cache_save_path, - # ) return # Exit the function if saving was successful except Exception as e: logger.error( diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index adb33324..e6fc5369 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -75,9 +75,27 @@ from swarms.structs.utils import ( find_token_in_text, parse_tasks, ) +from swarms.structs.swarm_router import ( + SwarmRouter, + SwarmType, + swarm_router, +) +from swarms.structs.swarm_arange import SwarmRearrange +from swarms.structs.multi_agent_exec import ( + run_agents_concurrently, + run_agents_concurrently_async, + run_single_agent, + run_agents_concurrently_multiprocess, + run_agents_sequentially, + run_agents_with_different_tasks, + run_agent_with_timeout, + run_agents_with_resource_monitoring, +) +from swarms.structs.async_workflow import AsyncWorkflow __all__ = [ "Agent", + "AsyncWorkflow", "AutoSwarm", "AutoSwarmRouter", "BaseStructure", @@ -142,6 +160,7 @@ __all__ = [ "run_agent_with_timeout", "run_agents_with_resource_monitoring", "swarm_router", + "AsyncWorkflow", "run_agents_with_tasks_concurrently", "showcase_available_agents", "GroupChatState", diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index d1f3a745..c9160b1b 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -338,6 +338,8 @@ class Agent: scheduled_run_date: Optional[datetime] = None, do_not_use_cluster_ops: bool = True, all_gpus: bool = False, + model_name: str = None, + llm_args: dict = None, *args, **kwargs, ): @@ -453,6 +455,8 @@ class Agent: self.scheduled_run_date = scheduled_run_date self.do_not_use_cluster_ops = do_not_use_cluster_ops self.all_gpus = all_gpus + self.model_name = model_name + self.llm_args = llm_args # Initialize the short term memory self.short_memory = Conversation( @@ -589,6 +593,21 @@ class Agent: # Telemetry Processor to log agent data threading.Thread(target=self.log_agent_data).start() + threading.Thread(target=self.llm_handling()) + + def llm_handling(self): + + if self.llm is None: + from swarms.utils.litellm import LiteLLM + + if self.llm_args is not None: + self.llm = LiteLLM( + model_name=self.model_name, **self.llm_args + ) + + else: + self.llm = LiteLLM(model_name=self.model_name) + def check_if_no_prompt_then_autogenerate(self, task: str = None): """ Checks if auto_generate_prompt is enabled and generates a prompt by combining agent name, description and system prompt if available. @@ -752,8 +771,11 @@ class Agent: self, task: Optional[str] = None, img: Optional[str] = None, + speech: Optional[str] = None, + video: Optional[str] = None, is_last: Optional[bool] = False, print_task: Optional[bool] = False, + generate_speech: Optional[bool] = False, *args, **kwargs, ) -> Any: @@ -951,7 +973,7 @@ class Agent: if self.interactive: logger.info("Interactive mode enabled.") - user_input = formatter.print_panel(input("You: ")) + user_input = input("You: ") # User-defined exit command if ( @@ -1015,6 +1037,11 @@ class Agent: self.artifacts_file_extension, ) + try: + self.log_agent_data() + except Exception: + pass + # More flexible output types if ( self.output_type == "string" @@ -1050,8 +1077,16 @@ class Agent: ) except Exception as error: + self.log_agent_data() logger.info( - f"Error running agent: {error} optimize your input parameter" + f"Error running agent: {error} optimize your input parameters" + ) + raise error + + except KeyboardInterrupt as error: + self.log_agent_data() + logger.info( + f"Error running agent: {error} optimize your input parameters" ) raise error @@ -1586,11 +1621,16 @@ class Agent: files = os.listdir(self.docs_folder) # Extract the text from the files + # Process each file and combine their contents + all_text = "" for file in files: - text = data_to_text(file) + file_path = os.path.join(self.docs_folder, file) + text = data_to_text(file_path) + all_text += f"\nContent from {file}:\n{text}\n" + # Add the combined content to memory return self.short_memory.add( - role=self.user_name, content=text + role=self.user_name, content=all_text ) except Exception as error: logger.error( @@ -2262,12 +2302,13 @@ class Agent: self, task: Optional[str] = None, img: Optional[str] = None, - device: str = "cpu", # gpu - device_id: int = 0, - all_cores: bool = True, + device: Optional[str] = "cpu", # gpu + device_id: Optional[int] = 0, + all_cores: Optional[bool] = True, scheduled_run_date: Optional[datetime] = None, - do_not_use_cluster_ops: bool = False, - all_gpus: bool = False, + do_not_use_cluster_ops: Optional[bool] = False, + all_gpus: Optional[bool] = False, + generate_speech: Optional[bool] = False, *args, **kwargs, ) -> Any: @@ -2314,7 +2355,12 @@ class Agent: # If cluster ops disabled, run directly if do_not_use_cluster_ops is True: logger.info("Running without cluster operations") - return self._run(task=task, img=img, *args, **kwargs) + return self._run( + task=task, + img=img, + generate_speech=generate_speech * args, + **kwargs, + ) else: return exec_callable_with_clusterops( @@ -2325,6 +2371,7 @@ class Agent: func=self._run, task=task, img=img, + generate_speech=generate_speech, *args, **kwargs, ) diff --git a/swarms/structs/async_workflow.py b/swarms/structs/async_workflow.py new file mode 100644 index 00000000..02ebe4df --- /dev/null +++ b/swarms/structs/async_workflow.py @@ -0,0 +1,62 @@ +import asyncio +from typing import Any, Callable, List, Optional +from swarms.structs.base_workflow import BaseWorkflow +from swarms.structs.agent import Agent +from swarms.utils.loguru_logger import logger + +class AsyncWorkflow(BaseWorkflow): + def __init__( + self, + name: str = "AsyncWorkflow", + agents: List[Agent] = None, + max_workers: int = 5, + dashboard: bool = False, + autosave: bool = False, + verbose: bool = False, + **kwargs + ): + super().__init__(agents=agents, **kwargs) + self.name = name + self.agents = agents or [] + self.max_workers = max_workers + self.dashboard = dashboard + self.autosave = autosave + self.verbose = verbose + self.task_pool = [] + self.results = [] + self.loop = None + + async def _execute_agent_task(self, agent: Agent, task: str) -> Any: + """Execute a single agent task asynchronously""" + try: + if self.verbose: + logger.info(f"Agent {agent.agent_name} processing task: {task}") + result = await agent.arun(task) + if self.verbose: + logger.info(f"Agent {agent.agent_name} completed task") + return result + except Exception as e: + logger.error(f"Error in agent {agent.agent_name}: {str(e)}") + return str(e) + + async def run(self, task: str) -> List[Any]: + """Run the workflow with all agents processing the task concurrently""" + if not self.agents: + raise ValueError("No agents provided to the workflow") + + try: + # Create tasks for all agents + tasks = [self._execute_agent_task(agent, task) for agent in self.agents] + + # Execute all tasks concurrently + self.results = await asyncio.gather(*tasks, return_exceptions=True) + + if self.autosave: + # TODO: Implement autosave logic here + pass + + return self.results + + except Exception as e: + logger.error(f"Error in workflow execution: {str(e)}") + raise \ No newline at end of file diff --git a/auto_swarm_builder.py b/swarms/structs/auto_swarm_builder.py similarity index 99% rename from auto_swarm_builder.py rename to swarms/structs/auto_swarm_builder.py index 8d981dda..93e542fd 100644 --- a/auto_swarm_builder.py +++ b/swarms/structs/auto_swarm_builder.py @@ -1,5 +1,3 @@ -from loguru import logger - import os from typing import List diff --git a/swarms/structs/graph_swarm.py b/swarms/structs/graph_swarm.py new file mode 100644 index 00000000..82cef523 --- /dev/null +++ b/swarms/structs/graph_swarm.py @@ -0,0 +1,665 @@ +""" +GraphSwarm: A production-grade framework for orchestrating swarms of agents +Author: Claude +License: MIT +Version: 2.0.0 +""" + +import asyncio +import json +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union + +import chromadb +import networkx as nx +from loguru import logger +from pydantic import BaseModel, Field + +from swarms import Agent + + +# Configure logging +logger.add( + "graphswarm.log", + rotation="500 MB", + retention="10 days", + level="INFO", + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", +) + + +class AgentOutput(BaseModel): + """Structured output from an agent.""" + + agent_name: str + timestamp: float = Field(default_factory=time.time) + output: Any + execution_time: float + error: Optional[str] = None + metadata: Dict = Field(default_factory=dict) + + +class SwarmOutput(BaseModel): + """Structured output from the entire swarm.""" + + timestamp: float = Field(default_factory=time.time) + outputs: Dict[str, AgentOutput] + execution_time: float + success: bool + error: Optional[str] = None + metadata: Dict = Field(default_factory=dict) + + +class SwarmMemory: + """Vector-based memory system for GraphSwarm using ChromaDB.""" + + def __init__(self, collection_name: str = "swarm_memories"): + """Initialize SwarmMemory with ChromaDB.""" + self.client = chromadb.Client() + + # Get or create collection + self.collection = self.client.get_or_create_collection( + name=collection_name, + metadata={"description": "GraphSwarm execution memories"}, + ) + + def store_execution(self, task: str, result: SwarmOutput): + """Store execution results in vector memory.""" + try: + # Create metadata + metadata = { + "timestamp": datetime.now().isoformat(), + "success": result.success, + "execution_time": result.execution_time, + "agent_sequence": json.dumps( + [name for name in result.outputs.keys()] + ), + "error": result.error if result.error else "", + } + + # Create document from outputs + document = { + "task": task, + "outputs": json.dumps( + { + name: { + "output": str(output.output), + "execution_time": output.execution_time, + "error": output.error, + } + for name, output in result.outputs.items() + } + ), + } + + # Store in ChromaDB + self.collection.add( + documents=[json.dumps(document)], + metadatas=[metadata], + ids=[f"exec_{datetime.now().timestamp()}"], + ) + + print("added to database") + + logger.info(f"Stored execution in memory: {task}") + + except Exception as e: + logger.error( + f"Failed to store execution in memory: {str(e)}" + ) + + def get_similar_executions(self, task: str, limit: int = 5): + """Retrieve similar past executions.""" + try: + # Query ChromaDB for similar executions + results = self.collection.query( + query_texts=[task], + n_results=limit, + include=["documents", "metadatas"], + ) + + print(results) + + if not results["documents"]: + return [] + + # Process results + executions = [] + for doc, metadata in zip( + results["documents"][0], results["metadatas"][0] + ): + doc_dict = json.loads(doc) + executions.append( + { + "task": doc_dict["task"], + "outputs": json.loads(doc_dict["outputs"]), + "success": metadata["success"], + "execution_time": metadata["execution_time"], + "agent_sequence": json.loads( + metadata["agent_sequence"] + ), + "timestamp": metadata["timestamp"], + } + ) + + return executions + + except Exception as e: + logger.error( + f"Failed to retrieve similar executions: {str(e)}" + ) + return [] + + def get_optimal_sequence(self, task: str) -> Optional[List[str]]: + """Get the most successful agent sequence for similar tasks.""" + similar_executions = self.get_similar_executions(task) + print(f"similar_executions {similar_executions}") + + if not similar_executions: + return None + + # Sort by success and execution time + successful_execs = [ + ex for ex in similar_executions if ex["success"] + ] + + if not successful_execs: + return None + + # Return sequence from most successful execution + return successful_execs[0]["agent_sequence"] + + def clear_memory(self): + """Clear all memories.""" + self.client.delete_collection(self.collection.name) + self.collection = self.client.get_or_create_collection( + name=self.collection.name + ) + + +class GraphSwarm: + """ + Enhanced framework for creating and managing swarms of collaborative agents. + """ + + def __init__( + self, + agents: Union[ + List[Agent], List[Tuple[Agent, List[str]]], None + ] = None, + max_workers: Optional[int] = None, + swarm_name: str = "Collaborative Agent Swarm", + memory_collection: str = "swarm_memory", + ): + """Initialize GraphSwarm.""" + self.graph = nx.DiGraph() + self.agents: Dict[str, Agent] = {} + self.dependencies: Dict[str, List[str]] = {} + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.swarm_name = swarm_name + self.memory_collection = memory_collection + self.memory = SwarmMemory(collection_name=memory_collection) + + if agents: + self.initialize_agents(agents) + + logger.info(f"Initialized GraphSwarm: {swarm_name}") + + def initialize_agents( + self, + agents: Union[List[Agent], List[Tuple[Agent, List[str]]]], + ): + """Initialize agents and their dependencies.""" + try: + # Handle list of Agents or (Agent, dependencies) tuples + for item in agents: + if isinstance(item, tuple): + agent, dependencies = item + else: + agent, dependencies = item, [] + + if not isinstance(agent, Agent): + raise ValueError( + f"Expected Agent object, got {type(agent)}" + ) + + self.agents[agent.agent_name] = agent + self.dependencies[agent.agent_name] = dependencies + self.graph.add_node(agent.agent_name, agent=agent) + + # Add dependencies + for dep in dependencies: + if dep not in self.agents: + raise ValueError( + f"Dependency {dep} not found for agent {agent.agent_name}" + ) + self.graph.add_edge(dep, agent.agent_name) + + self._validate_graph() + + except Exception as e: + logger.error(f"Failed to initialize agents: {str(e)}") + raise + + def _validate_graph(self): + """Validate the agent dependency graph.""" + if not self.graph.nodes(): + raise ValueError("No agents added to swarm") + + if not nx.is_directed_acyclic_graph(self.graph): + cycles = list(nx.simple_cycles(self.graph)) + raise ValueError( + f"Agent dependency graph contains cycles: {cycles}" + ) + + def _get_agent_role_description(self, agent_name: str) -> str: + """Generate a description of the agent's role in the swarm.""" + predecessors = list(self.graph.predecessors(agent_name)) + successors = list(self.graph.successors(agent_name)) + position = ( + "initial" + if not predecessors + else ("final" if not successors else "intermediate") + ) + + role = f"""You are {agent_name}, a specialized agent in the {self.swarm_name}. + Position: {position} agent in the workflow + + Your relationships:""" + + if predecessors: + role += ( + f"\nYou receive input from: {', '.join(predecessors)}" + ) + if successors: + role += f"\nYour output will be used by: {', '.join(successors)}" + + return role + + def _generate_workflow_context(self) -> str: + """Generate a description of the entire workflow.""" + execution_order = list(nx.topological_sort(self.graph)) + + workflow = f"""Workflow Overview of {self.swarm_name}: + + Processing Order: + {' -> '.join(execution_order)} + + Agent Roles: + """ + + for agent_name in execution_order: + predecessors = list(self.graph.predecessors(agent_name)) + successors = list(self.graph.successors(agent_name)) + + workflow += f"\n\n{agent_name}:" + if predecessors: + workflow += ( + f"\n- Receives from: {', '.join(predecessors)}" + ) + if successors: + workflow += f"\n- Sends to: {', '.join(successors)}" + if not predecessors and not successors: + workflow += "\n- Independent agent" + + return workflow + + def _build_agent_prompt( + self, agent_name: str, task: str, context: Dict = None + ) -> str: + """Build a comprehensive prompt for the agent including role and context.""" + prompt_parts = [ + self._get_agent_role_description(agent_name), + "\nWorkflow Context:", + self._generate_workflow_context(), + "\nYour Task:", + task, + ] + + if context: + prompt_parts.extend( + ["\nContext from Previous Agents:", str(context)] + ) + + prompt_parts.extend( + [ + "\nInstructions:", + "1. Process the task according to your role", + "2. Consider the input from previous agents when available", + "3. Provide clear, structured output", + "4. Remember that your output will be used by subsequent agents", + "\nResponse Guidelines:", + "- Provide clear, well-organized output", + "- Include relevant details and insights", + "- Highlight key findings", + "- Flag any uncertainties or issues", + ] + ) + + return "\n".join(prompt_parts) + + async def _execute_agent( + self, agent_name: str, task: str, context: Dict = None + ) -> AgentOutput: + """Execute a single agent.""" + start_time = time.time() + agent = self.agents[agent_name] + + try: + # Build comprehensive prompt + full_prompt = self._build_agent_prompt( + agent_name, task, context + ) + logger.debug(f"Prompt for {agent_name}:\n{full_prompt}") + + # Execute agent + output = await asyncio.to_thread(agent.run, full_prompt) + + return AgentOutput( + agent_name=agent_name, + output=output, + execution_time=time.time() - start_time, + metadata={ + "task": task, + "context": context, + "position_in_workflow": list( + nx.topological_sort(self.graph) + ).index(agent_name), + }, + ) + + except Exception as e: + logger.error( + f"Error executing agent {agent_name}: {str(e)}" + ) + return AgentOutput( + agent_name=agent_name, + output=None, + execution_time=time.time() - start_time, + error=str(e), + metadata={"task": task}, + ) + + async def execute(self, task: str) -> SwarmOutput: + """ + Execute the entire swarm of agents with memory integration. + + Args: + task: Initial task to execute + + Returns: + SwarmOutput: Structured output from all agents + """ + start_time = time.time() + outputs = {} + success = True + error = None + + try: + # Get similar past executions + similar_executions = self.memory.get_similar_executions( + task, limit=3 + ) + optimal_sequence = self.memory.get_optimal_sequence(task) + + # Get base execution order + base_execution_order = list( + nx.topological_sort(self.graph) + ) + + # Determine final execution order + if optimal_sequence and all( + agent in base_execution_order + for agent in optimal_sequence + ): + logger.info( + f"Using optimal sequence from memory: {optimal_sequence}" + ) + execution_order = optimal_sequence + else: + execution_order = base_execution_order + + # Get historical context if available + historical_context = {} + if similar_executions: + best_execution = similar_executions[0] + if best_execution["success"]: + historical_context = { + "similar_task": best_execution["task"], + "previous_outputs": best_execution["outputs"], + "execution_time": best_execution[ + "execution_time" + ], + "success_patterns": self._extract_success_patterns( + similar_executions + ), + } + + # Execute agents in order + for agent_name in execution_order: + try: + # Get context from dependencies and history + agent_context = { + "dependencies": { + dep: outputs[dep].output + for dep in self.graph.predecessors( + agent_name + ) + if dep in outputs + }, + "historical": historical_context, + "position": execution_order.index(agent_name), + "total_agents": len(execution_order), + } + + # Execute agent with enhanced context + output = await self._execute_agent( + agent_name, task, agent_context + ) + outputs[agent_name] = output + + # Update historical context with current execution + if output.output: + historical_context.update( + { + f"current_{agent_name}_output": output.output + } + ) + + # Check for errors + if output.error: + success = False + error = f"Agent {agent_name} failed: {output.error}" + + # Try to recover using memory + if similar_executions: + recovery_output = self._attempt_recovery( + agent_name, task, similar_executions + ) + if recovery_output: + outputs[agent_name] = recovery_output + success = True + error = None + continue + break + + except Exception as agent_error: + logger.error( + f"Error executing agent {agent_name}: {str(agent_error)}" + ) + success = False + error = f"Agent {agent_name} failed: {str(agent_error)}" + break + + # Create result + result = SwarmOutput( + outputs=outputs, + execution_time=time.time() - start_time, + success=success, + error=error, + metadata={ + "task": task, + "used_optimal_sequence": optimal_sequence + is not None, + "similar_executions_found": len( + similar_executions + ), + "execution_order": execution_order, + "historical_context_used": bool( + historical_context + ), + }, + ) + + # Store execution in memory + await self._store_execution_async(task, result) + + return result + + except Exception as e: + logger.error(f"Swarm execution failed: {str(e)}") + return SwarmOutput( + outputs=outputs, + execution_time=time.time() - start_time, + success=False, + error=str(e), + metadata={"task": task}, + ) + + def run(self, task: str) -> SwarmOutput: + """Synchronous interface to execute the swarm.""" + return asyncio.run(self.execute(task)) + + def _extract_success_patterns( + self, similar_executions: List[Dict] + ) -> Dict: + """Extract success patterns from similar executions.""" + patterns = {} + successful_execs = [ + ex for ex in similar_executions if ex["success"] + ] + + if successful_execs: + patterns = { + "common_sequences": self._find_common_sequences( + successful_execs + ), + "avg_execution_time": sum( + ex["execution_time"] for ex in successful_execs + ) + / len(successful_execs), + "successful_strategies": self._extract_strategies( + successful_execs + ), + } + + return patterns + + def _attempt_recovery( + self, + failed_agent: str, + task: str, + similar_executions: List[Dict], + ) -> Optional[AgentOutput]: + """Attempt to recover from failure using memory.""" + for execution in similar_executions: + if ( + execution["success"] + and failed_agent in execution["outputs"] + ): + historical_output = execution["outputs"][failed_agent] + + return AgentOutput( + agent_name=failed_agent, + output=historical_output["output"], + execution_time=historical_output[ + "execution_time" + ], + metadata={ + "recovered_from_memory": True, + "original_task": execution["task"], + }, + ) + return None + + async def _store_execution_async( + self, task: str, result: SwarmOutput + ): + """Asynchronously store execution in memory.""" + try: + await asyncio.to_thread( + self.memory.store_execution, task, result + ) + except Exception as e: + logger.error( + f"Failed to store execution in memory: {str(e)}" + ) + + def add_agent(self, agent: Agent, dependencies: List[str] = None): + """Add a new agent to the swarm.""" + dependencies = dependencies or [] + self.agents[agent.agent_name] = agent + self.dependencies[agent.agent_name] = dependencies + self.graph.add_node(agent.agent_name, agent=agent) + + for dep in dependencies: + if dep not in self.agents: + raise ValueError(f"Dependency {dep} not found") + self.graph.add_edge(dep, agent.agent_name) + + self._validate_graph() + + +if __name__ == "__main__": + try: + # Create agents + data_collector = Agent( + agent_name="Market-Data-Collector", + model_name="gpt-4o-mini", + max_loops=1, + streaming_on=True, + ) + + trend_analyzer = Agent( + agent_name="Market-Trend-Analyzer", + model_name="gpt-4o-mini", + max_loops=1, + streaming_on=True, + ) + + report_generator = Agent( + agent_name="Investment-Report-Generator", + model_name="gpt-4o-mini", + max_loops=1, + streaming_on=True, + ) + + # Create swarm + swarm = GraphSwarm( + agents=[ + (data_collector, []), + (trend_analyzer, ["Market-Data-Collector"]), + (report_generator, ["Market-Trend-Analyzer"]), + ], + swarm_name="Market Analysis Intelligence Network", + ) + + # Run the swarm + result = swarm.run( + "Analyze current market trends for tech stocks and provide investment recommendations" + ) + + # Print results + print(f"Execution success: {result.success}") + print(f"Total time: {result.execution_time:.2f} seconds") + + for agent_name, output in result.outputs.items(): + print(f"\nAgent: {agent_name}") + print(f"Output: {output.output}") + if output.error: + print(f"Error: {output.error}") + except Exception as error: + logger.error(error) + raise error diff --git a/groupchat_new.py b/swarms/structs/groupchat_new.py similarity index 100% rename from groupchat_new.py rename to swarms/structs/groupchat_new.py diff --git a/swarms/structs/pulsar_swarm.py b/swarms/structs/pulsar_swarm.py new file mode 100644 index 00000000..2d8961f7 --- /dev/null +++ b/swarms/structs/pulsar_swarm.py @@ -0,0 +1,276 @@ +import asyncio +import pulsar + +from pulsar import ConsumerType +from loguru import logger +from swarms import Agent +from typing import List, Dict, Any +import json + + +class ScalableAsyncAgentSwarm: + """ + A scalable, asynchronous swarm of agents leveraging Apache Pulsar for inter-agent communication. + Provides load balancing, health monitoring, dead letter queues, and centralized logging. + """ + + def __init__( + self, + pulsar_url: str, + topic: str, + dlq_topic: str, + agents_config: List[Dict[str, Any]], + ): + """ + Initializes the async swarm with agents. + + Args: + pulsar_url (str): The URL of the Apache Pulsar broker. + topic (str): The main topic for task distribution. + dlq_topic (str): The Dead Letter Queue topic for failed messages. + agents_config (List[Dict[str, Any]]): List of agent configurations with `name`, `description`, and `model_name`. + """ + self.pulsar_url = pulsar_url + self.topic = topic + self.dlq_topic = dlq_topic + self.agents_config = agents_config + self.client = pulsar.Client(pulsar_url) + self.consumer = self.client.subscribe( + topic, + subscription_name="swarm-task-sub", + consumer_type=ConsumerType.Shared, + ) + self.dlq_producer = self.client.create_producer(dlq_topic) + self.response_logger = [] + self.agents = [ + self.create_agent(config) for config in agents_config + ] + self.agent_index = 0 + + logger.info( + "Swarm initialized with agents: {}", + [agent["name"] for agent in agents_config], + ) + + def create_agent( + self, agent_config: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Creates a new agent configuration with asynchronous capabilities. + + Args: + agent_config (Dict[str, Any]): Configuration dictionary with agent details. + + Returns: + Dict[str, Any]: A dictionary containing agent metadata and functionality. + """ + agent_name = agent_config["name"] + description = agent_config["description"] + model_name = agent_config.get("model_name", "gpt-4o-mini") + + class AsyncAgent: + """ + An asynchronous agent that processes tasks and communicates via Apache Pulsar. + """ + + def __init__( + self, name: str, description: str, model_name: str + ): + self.name = name + self.description = description + self.agent = Agent( + agent_name=name, + model_name=model_name, + max_loops="auto", + interactive=True, + streaming_on=True, + ) + logger.info( + f"Initialized agent '{name}' - {description}" + ) + + async def process_task( + self, message: str + ) -> Dict[str, Any]: + """ + Processes a single task using the agent. + + Args: + message (str): The task message. + + Returns: + Dict[str, Any]: JSON-formatted response. + """ + try: + logger.info( + f"Agent {self.name} processing task: {message}" + ) + response = await asyncio.to_thread( + self.agent.run, message + ) + logger.info(f"Agent {self.name} completed task.") + return { + "agent_name": self.name, + "response": response, + } + except Exception as e: + logger.error( + f"Agent {self.name} encountered an error: {e}" + ) + return {"agent_name": self.name, "error": str(e)} + + return { + "name": agent_name, + "instance": AsyncAgent( + agent_name, description, model_name + ), + } + + async def distribute_task(self, message: str): + """ + Distributes a task to the next available agent using round-robin. + + Args: + message (str): The task message. + """ + agent = self.agents[self.agent_index] + self.agent_index = (self.agent_index + 1) % len(self.agents) + + try: + response = await agent["instance"].process_task(message) + self.log_response(response) + except Exception as e: + logger.error( + f"Error processing task by agent {agent['name']}: {e}" + ) + self.send_to_dlq(message) + + async def monitor_health(self): + """ + Periodically monitors the health of agents. + """ + while True: + logger.info("Performing health check for all agents.") + for agent in self.agents: + logger.info(f"Agent {agent['name']} is online.") + await asyncio.sleep(10) + + def send_to_dlq(self, message: str): + """ + Sends a failed message to the Dead Letter Queue (DLQ). + + Args: + message (str): The message to send to the DLQ. + """ + try: + self.dlq_producer.send(message.encode("utf-8")) + logger.info("Message sent to Dead Letter Queue.") + except Exception as e: + logger.error(f"Failed to send message to DLQ: {e}") + + def log_response(self, response: Dict[str, Any]): + """ + Logs the response to a centralized list for later analysis. + + Args: + response (Dict[str, Any]): The agent's response. + """ + self.response_logger.append(response) + logger.info(f"Response logged: {response}") + + async def listen_and_distribute(self): + """ + Listens to the main Pulsar topic and distributes tasks to agents. + """ + while True: + msg = self.consumer.receive() + try: + message = msg.data().decode("utf-8") + logger.info(f"Received task: {message}") + await self.distribute_task(message) + self.consumer.acknowledge(msg) + except Exception as e: + logger.error(f"Error processing message: {e}") + self.send_to_dlq(msg.data().decode("utf-8")) + self.consumer.negative_acknowledge(msg) + + async def run(self): + """ + Runs the swarm asynchronously with health monitoring and task distribution. + """ + logger.info("Starting the async swarm...") + task_listener = asyncio.create_task( + self.listen_and_distribute() + ) + health_monitor = asyncio.create_task(self.monitor_health()) + await asyncio.gather(task_listener, health_monitor) + + def shutdown(self): + """ + Safely shuts down the swarm and logs all responses. + """ + logger.info("Shutting down the swarm...") + self.client.close() + with open("responses.json", "w") as f: + json.dump(self.response_logger, f, indent=4) + logger.info("Responses saved to 'responses.json'.") + + +# from scalable_agent_swarm import ScalableAsyncAgentSwarm # Assuming your swarm class is saved here + +if __name__ == "__main__": + # Example Configuration + PULSAR_URL = "pulsar://localhost:6650" + TOPIC = "stock-analysis" + DLQ_TOPIC = "stock-analysis-dlq" + + # Agents configuration + AGENTS_CONFIG = [ + { + "name": "Stock-Analysis-Agent-1", + "description": "Analyzes stock trends.", + "model_name": "gpt-4o-mini", + }, + { + "name": "Stock-News-Agent", + "description": "Summarizes stock news.", + "model_name": "gpt-4o-mini", + }, + { + "name": "Tech-Trends-Agent", + "description": "Tracks tech sector trends.", + "model_name": "gpt-4o-mini", + }, + ] + + # Tasks to send + TASKS = [ + "Analyze the trend for tech stocks in Q4 2024", + "Summarize the latest news on the S&P 500", + "Identify the top-performing sectors in the stock market", + "Provide a forecast for AI-related stocks for 2025", + ] + + # Initialize and run the swarm + swarm = ScalableAsyncAgentSwarm( + PULSAR_URL, TOPIC, DLQ_TOPIC, AGENTS_CONFIG + ) + try: + # Run the swarm in the background + swarm_task = asyncio.create_task(swarm.run()) + + # Send tasks to the topic + client = pulsar.Client(PULSAR_URL) + producer = client.create_producer(TOPIC) + + for task in TASKS: + producer.send(task.encode("utf-8")) + print(f"Sent task: {task}") + + producer.close() + client.close() + + # Keep the swarm running + asyncio.run(swarm_task) + except KeyboardInterrupt: + swarm.shutdown() diff --git a/swarms/utils/any_to_str.py b/swarms/utils/any_to_str.py index 125e233e..2b0e3809 100644 --- a/swarms/utils/any_to_str.py +++ b/swarms/utils/any_to_str.py @@ -63,40 +63,40 @@ def any_to_str(data: Union[str, Dict, List, Tuple, Any]) -> str: return f"Error converting data: {str(e)}" -def main(): - # Example 1: Dictionary - print("Dictionary:") - print( - any_to_str( - { - "name": "John", - "age": 30, - "hobbies": ["reading", "hiking"], - } - ) - ) - - print("\nNested Dictionary:") - print( - any_to_str( - { - "user": { - "id": 123, - "details": {"city": "New York", "active": True}, - }, - "data": [1, 2, 3], - } - ) - ) - - print("\nList and Tuple:") - print(any_to_str([1, "text", None, (1, 2)])) - print(any_to_str((True, False, None))) - - print("\nEmpty Collections:") - print(any_to_str([])) - print(any_to_str({})) - - -if __name__ == "__main__": - main() +# def main(): +# # Example 1: Dictionary +# print("Dictionary:") +# print( +# any_to_str( +# { +# "name": "John", +# "age": 30, +# "hobbies": ["reading", "hiking"], +# } +# ) +# ) + +# print("\nNested Dictionary:") +# print( +# any_to_str( +# { +# "user": { +# "id": 123, +# "details": {"city": "New York", "active": True}, +# }, +# "data": [1, 2, 3], +# } +# ) +# ) + +# print("\nList and Tuple:") +# print(any_to_str([1, "text", None, (1, 2)])) +# print(any_to_str((True, False, None))) + +# print("\nEmpty Collections:") +# print(any_to_str([])) +# print(any_to_str({})) + + +# if __name__ == "__main__": +# main() diff --git a/swarms/utils/calculate_func_metrics.py b/swarms/utils/calculate_func_metrics.py index bfb8a528..795e7bb2 100644 --- a/swarms/utils/calculate_func_metrics.py +++ b/swarms/utils/calculate_func_metrics.py @@ -4,7 +4,6 @@ from functools import wraps from typing import Any, Callable import psutil -from loguru import logger from pydantic import BaseModel from swarms.utils.loguru_logger import initialize_logger diff --git a/swarms/utils/callable_name.py b/swarms/utils/callable_name.py deleted file mode 100644 index 9a0b037f..00000000 --- a/swarms/utils/callable_name.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any -import inspect -from functools import partial -import logging - - -class NameResolver: - """Utility class for resolving names of various objects""" - - @staticmethod - def get_name(obj: Any, default: str = "unnamed_callable") -> str: - """ - Get the name of any object with multiple fallback strategies. - - Args: - obj: The object to get the name from - default: Default name if all strategies fail - - Returns: - str: The resolved name - """ - strategies = [ - # Try getting __name__ attribute - lambda x: getattr(x, "__name__", None), - # Try getting class name - lambda x: ( - x.__class__.__name__ - if hasattr(x, "__class__") - else None - ), - # Try getting function name if it's a partial - lambda x: ( - x.func.__name__ if isinstance(x, partial) else None - ), - # Try getting the name from the class's type - lambda x: type(x).__name__, - # Try getting qualname - lambda x: getattr(x, "__qualname__", None), - # Try getting the module and class name - lambda x: ( - f"{x.__module__}.{x.__class__.__name__}" - if hasattr(x, "__module__") - else None - ), - # For async functions - lambda x: ( - x.__name__ if inspect.iscoroutinefunction(x) else None - ), - # For classes with custom __str__ - lambda x: ( - str(x) - if hasattr(x, "__str__") - and x.__str__ != object.__str__ - else None - ), - # For wrapped functions - lambda x: ( - getattr(x, "__wrapped__", None).__name__ - if hasattr(x, "__wrapped__") - else None - ), - ] - - # Try each strategy - for strategy in strategies: - try: - name = strategy(obj) - if name and isinstance(name, str): - return name.replace(" ", "_").replace("-", "_") - except Exception: - continue - - # Return default if all strategies fail - return default - - @staticmethod - def get_callable_details(obj: Any) -> dict: - """ - Get detailed information about a callable object. - - Returns: - dict: Dictionary containing: - - name: The resolved name - - type: The type of callable - - signature: The signature if available - - module: The module name if available - - doc: The docstring if available - """ - details = { - "name": NameResolver.get_name(obj), - "type": "unknown", - "signature": None, - "module": getattr(obj, "__module__", "unknown"), - "doc": inspect.getdoc(obj) - or "No documentation available", - } - - # Determine the type - if inspect.isclass(obj): - details["type"] = "class" - elif inspect.iscoroutinefunction(obj): - details["type"] = "async_function" - elif inspect.isfunction(obj): - details["type"] = "function" - elif isinstance(obj, partial): - details["type"] = "partial" - elif callable(obj): - details["type"] = "callable" - - # Try to get signature - try: - details["signature"] = str(inspect.signature(obj)) - except (ValueError, TypeError): - details["signature"] = "Unknown signature" - - return details - - @classmethod - def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str: - """ - Safely get a name with retries and validation. - - Args: - obj: Object to get name from - max_retries: Maximum number of retry attempts - - Returns: - str: A valid name string - """ - retries = 0 - last_error = None - - while retries < max_retries: - try: - name = cls.get_name(obj) - - # Validate and clean the name - if name: - # Remove invalid characters - clean_name = "".join( - c - for c in name - if c.isalnum() or c in ["_", "."] - ) - - # Ensure it starts with a letter or underscore - if ( - not clean_name[0].isalpha() - and clean_name[0] != "_" - ): - clean_name = f"_{clean_name}" - - return clean_name - - except Exception as e: - last_error = e - retries += 1 - - # If all retries failed, generate a unique fallback name - import uuid - - fallback = f"callable_{uuid.uuid4().hex[:8]}" - logging.warning( - f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. " - f"Last error: {str(last_error)}" - ) - return fallback - - -# # Example usage -# if __name__ == "__main__": -# def test_resolver(): -# # Test cases -# class TestClass: -# def method(self): -# pass - -# async def async_func(): -# pass - -# test_cases = [ -# TestClass, # Class -# TestClass(), # Instance -# async_func, # Async function -# lambda x: x, # Lambda -# partial(print, end=""), # Partial -# TestClass.method, # Method -# print, # Built-in function -# str, # Built-in class -# ] - -# resolver = NameResolver() - -# print("\nName Resolution Results:") -# print("-" * 50) -# for obj in test_cases: -# details = resolver.get_callable_details(obj) -# safe_name = resolver.get_safe_name(obj) -# print(f"\nObject: {obj}") -# print(f"Safe Name: {safe_name}") -# print(f"Details: {details}") - -# test_resolver() diff --git a/swarms/utils/dict_to_table.py b/swarms/utils/dict_to_table.py deleted file mode 100644 index e69de29b..00000000 diff --git a/swarms/utils/litellm.py b/swarms/utils/litellm.py new file mode 100644 index 00000000..5bdd208d --- /dev/null +++ b/swarms/utils/litellm.py @@ -0,0 +1,105 @@ +try: + from litellm import completion +except ImportError: + import subprocess + + subprocess.check_call(["pip", "install", "litellm"]) + import litellm + from litellm import completion + + litellm.set_verbose = True + + +class LiteLLM: + """ + This class represents a LiteLLM. + It is used to interact with the LLM model for various tasks. + """ + + def __init__( + self, + model_name: str = "gpt-4o", + system_prompt: str = None, + stream: bool = False, + temperature: float = 0.5, + max_tokens: int = 4000, + ): + """ + Initialize the LiteLLM with the given parameters. + + Args: + model_name (str, optional): The name of the model to use. Defaults to "gpt-4o". + system_prompt (str, optional): The system prompt to use. Defaults to None. + stream (bool, optional): Whether to stream the output. Defaults to False. + temperature (float, optional): The temperature for the model. Defaults to 0.5. + max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000. + """ + self.model_name = model_name + self.system_prompt = system_prompt + self.stream = stream + self.temperature = temperature + self.max_tokens = max_tokens + + def _prepare_messages(self, task: str) -> list: + """ + Prepare the messages for the given task. + + Args: + task (str): The task to prepare messages for. + + Returns: + list: A list of messages prepared for the task. + """ + messages = [] + + if self.system_prompt: # Check if system_prompt is not None + messages.append( + {"role": "system", "content": self.system_prompt} + ) + + messages.append({"role": "user", "content": task}) + + return messages + + def run(self, task: str, *args, **kwargs): + """ + Run the LLM model for the given task. + + Args: + task (str): The task to run the model for. + *args: Additional positional arguments to pass to the model. + **kwargs: Additional keyword arguments to pass to the model. + + Returns: + str: The content of the response from the model. + """ + messages = self._prepare_messages(task) + + response = completion( + model=self.model_name, + messages=messages, + stream=self.stream, + temperature=self.temperature, + # max_completion_tokens=self.max_tokens, + max_tokens=self.max_tokens, + *args, + **kwargs, + ) + content = response.choices[ + 0 + ].message.content # Accessing the content + return content + + def __call__(self, task: str, *args, **kwargs): + """ + Call the LLM model for the given task. + + Args: + task (str): The task to run the model for. + *args: Additional positional arguments to pass to the model. + **kwargs: Additional keyword arguments to pass to the model. + + Returns: + str: The content of the response from the model. + """ + return self.run(task, *args, **kwargs) diff --git a/swarms/utils/openai_tts.py b/swarms/utils/openai_tts.py new file mode 100644 index 00000000..3cfcbd05 --- /dev/null +++ b/swarms/utils/openai_tts.py @@ -0,0 +1,73 @@ +import os +from loguru import logger +import pygame +import requests +import tempfile +from openai import OpenAI + + +class OpenAITTS: + """ + A class to interact with OpenAI API and play the generated audio with improved streaming capabilities. + """ + + def __init__(self, *args, **kwargs): + self.client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs + ) + pygame.init() + + def run( + self, task: str, play_sound: bool = True, *args, **kwargs + ): + """ + Run a task with the OpenAI API and optionally play the generated audio with improved streaming. + + Args: + task (str): The task to be executed. + play_sound (bool): If True, play the generated audio. + + Returns: + None + """ + try: + response = self.client.audio.speech.create( + model="tts-1", + voice="nova", + input=task, + *args, + **kwargs, + ) + audio_url = response["url"] + logger.info("Task completed successfully.") + + if play_sound: + with tempfile.NamedTemporaryFile( + delete=False, suffix=".mp3" + ) as tmp_file: + with requests.get(audio_url, stream=True) as r: + r.raise_for_status() + for chunk in r.iter_content(chunk_size=8192): + tmp_file.write(chunk) + pygame.mixer.music.load(tmp_file.name) + pygame.mixer.music.play() + while pygame.mixer.music.get_busy(): + pygame.time.Clock().tick(10) + except Exception as e: + logger.error(f"Error during task execution: {str(e)}") + + +# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY")) +# client.run("Hello world! This is a streaming test.", play_sound=True) + + +def text_to_speech( + task: str, play_sound: bool = True, *args, **kwargs +): + out = OpenAITTS().run( + task, play_sound=play_sound, *args, **kwargs + ) + return out + + +# print(text_to_speech(task="hello")) diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index f295340c..c962c5d8 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -1,50 +1,64 @@ import re -def extract_code_from_markdown(markdown_content: str) -> str: +def extract_code_blocks_with_language(markdown_text: str): """ - Extracts code blocks from a Markdown string and returns them as a single string. + Extracts all code blocks from Markdown text along with their languages. Args: - - markdown_content (str): The Markdown content as a string. + markdown_text (str): The input Markdown text. Returns: - - str: A single string containing all the code blocks separated by newlines. + list[dict]: A list of dictionaries, each containing: + - 'language': The detected language (or 'plaintext' if none specified). + - 'content': The content of the code block. """ - # Regular expression for fenced code blocks with optional language specifier - pattern = r"```(?:\w+\n)?(.*?)```" + # Regex pattern to match code blocks and optional language specifiers + pattern = r"```(\w+)?\n(.*?)```" - # Check if markdown_content is a string - if not isinstance(markdown_content, str): - raise TypeError("markdown_content must be a string") + # Find all matches (language and content) + matches = re.findall(pattern, markdown_text, re.DOTALL) - # Find all matches of the pattern - matches = re.finditer(pattern, markdown_content, re.DOTALL) - - # Extract the content inside the backticks + # Parse results code_blocks = [] - for match in matches: - code_block = match.group(1).strip() - # Remove any leading or trailing whitespace from the code block - code_block = code_block.strip() - # Remove any empty lines from the code block - code_block = "\n".join( - [line for line in code_block.split("\n") if line.strip()] + for language, content in matches: + language = ( + language.strip() if language else "plaintext" + ) # Default to 'plaintext' + code_blocks.append( + {"language": language, "content": content.strip()} ) - code_blocks.append(code_block) - # Concatenate all code blocks separated by newlines - if code_blocks: - return "\n\n".join(code_blocks) - else: - return "" + return code_blocks + + +def extract_code_from_markdown( + markdown_text: str, language: str = None +): + """ + Extracts content of code blocks for a specific language or all blocks if no language specified. + Args: + markdown_text (str): The input Markdown text. + language (str, optional): The language to filter by (e.g., 'yaml', 'python'). -# example = """ -# hello im an agent -# ```bash -# pip install swarms -# ``` -# """ + Returns: + str: The concatenated content of matched code blocks or an empty string if none found. + """ + # Get all code blocks with detected languages + code_blocks = extract_code_blocks_with_language(markdown_text) + + # Filter by language if specified + if language: + code_blocks = [ + block["content"] + for block in code_blocks + if block["language"] == language + ] + else: + code_blocks = [ + block["content"] for block in code_blocks + ] # Include all blocks -# print(extract_code_from_markdown(example)) # Output: { "type": "function", "function": { "name": "fetch_financial_news", "parameters": { "query": "Nvidia news", "num_articles": 5 } } } + # Return concatenated content + return "\n\n".join(code_blocks) if code_blocks else "" diff --git a/swarms/utils/pdf_to_text.py b/swarms/utils/pdf_to_text.py index 90711691..8df8e065 100644 --- a/swarms/utils/pdf_to_text.py +++ b/swarms/utils/pdf_to_text.py @@ -1,14 +1,12 @@ -import sys from swarms.utils.try_except_wrapper import try_except_wrapper try: import pypdf except ImportError: - print( - "pypdf not installed. Please install it using: pip install" - " pypdf" - ) - sys.exit(1) + import subprocess + + subprocess.check_call(["python", "-m", "pip", "install", "pypdf"]) + import pypdf @try_except_wrapper diff --git a/swarms/utils/remove_json_whitespace.py b/swarms/utils/remove_json_whitespace.py deleted file mode 100644 index 0a043e7c..00000000 --- a/swarms/utils/remove_json_whitespace.py +++ /dev/null @@ -1,51 +0,0 @@ -import json - -import yaml - - -def remove_whitespace_from_json(json_string: str) -> str: - """ - Removes unnecessary whitespace from a JSON string. - - This function parses the JSON string into a Python object and then - serializes it back into a JSON string without unnecessary whitespace. - - Args: - json_string (str): The JSON string. - - Returns: - str: The JSON string with whitespace removed. - """ - parsed = json.loads(json_string) - return json.dumps(parsed, separators=(",", ":")) - - -# # Example usage for JSON -# json_string = '{"field1": 123, "field2": "example text"}' -# print(remove_whitespace_from_json(json_string)) - - -def remove_whitespace_from_yaml(yaml_string: str) -> str: - """ - Removes unnecessary whitespace from a YAML string. - - This function parses the YAML string into a Python object and then - serializes it back into a YAML string with minimized whitespace. - Note: This might change the representation style of YAML data. - - Args: - yaml_string (str): The YAML string. - - Returns: - str: The YAML string with whitespace reduced. - """ - parsed = yaml.safe_load(yaml_string) - return yaml.dump(parsed, default_flow_style=True) - - -# # Example usage for YAML -# yaml_string = """ -# field1: 123 -# field2: example text -# """ -# print(remove_whitespace_from_yaml(yaml_string)) diff --git a/test.py b/test.py deleted file mode 100644 index ce12ec1c..00000000 --- a/test.py +++ /dev/null @@ -1,292 +0,0 @@ -import torch -import torch.nn as nn -import torch.distributed as dist -from dataclasses import dataclass -from typing import Optional, Tuple, Union -from loguru import logger -import math - - -@dataclass -class StarAttentionConfig: - """Configuration for StarAttention module. - - Attributes: - hidden_size: Dimension of the model's hidden states - num_attention_heads: Number of attention heads - num_hosts: Number of hosts in the distributed system - block_size: Size of each context block - anchor_size: Size of the anchor block - dropout_prob: Dropout probability (default: 0.1) - layer_norm_eps: Layer normalization epsilon (default: 1e-12) - """ - - hidden_size: int - num_attention_heads: int - num_hosts: int - block_size: int - anchor_size: int - dropout_prob: float = 0.1 - layer_norm_eps: float = 1e-12 - - -class StarAttention(nn.Module): - """ - Implementation of Star Attention mechanism for distributed inference. - - The module implements a two-phase attention mechanism: - 1. Local Context Encoding with Anchor Blocks - 2. Query Encoding and Output Generation with Global Attention - """ - - def __init__(self, config: StarAttentionConfig): - super().__init__() - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"Hidden size {config.hidden_size} not divisible by number of attention " - f"heads {config.num_attention_heads}" - ) - - self.config = config - self.head_dim = ( - config.hidden_size // config.num_attention_heads - ) - - # Initialize components - self.query = nn.Linear(config.hidden_size, config.hidden_size) - self.key = nn.Linear(config.hidden_size, config.hidden_size) - self.value = nn.Linear(config.hidden_size, config.hidden_size) - - self.dropout = nn.Dropout(config.dropout_prob) - self.layer_norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - - # KV cache for storing computed key/value pairs - self.kv_cache = {} - - logger.info( - f"Initialized StarAttention with config: {config}" - ) - - def _split_heads( - self, tensor: torch.Tensor, num_heads: int - ) -> torch.Tensor: - """Split the last dimension into (num_heads, head_dim).""" - batch_size, seq_len, _ = tensor.size() - tensor = tensor.view( - batch_size, seq_len, num_heads, self.head_dim - ) - # Transpose to (batch_size, num_heads, seq_len, head_dim) - return tensor.transpose(1, 2) - - def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor: - """Merge the head dimension back into hidden_size.""" - batch_size, _, seq_len, _ = tensor.size() - tensor = tensor.transpose(1, 2) - return tensor.reshape( - batch_size, seq_len, self.config.hidden_size - ) - - def _compute_attention_scores( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute attention scores and weighted values.""" - # Scale dot-product attention - scores = torch.matmul( - query, key.transpose(-2, -1) - ) / math.sqrt(self.head_dim) - - if mask is not None: - scores = scores.masked_fill(mask == 0, float("-inf")) - - # Online softmax computation - attention_probs = torch.nn.functional.softmax(scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - context = torch.matmul(attention_probs, value) - - return context, attention_probs - - def phase1_local_context_encoding( - self, - input_ids: torch.Tensor, - host_id: int, - device: Union[str, torch.device] = "cuda", - ) -> None: - """ - Phase 1: Local Context Encoding with Anchor Blocks - - Args: - input_ids: Input tensor of shape (batch_size, seq_len) - host_id: ID of the current host - device: Device to run computations on - """ - logger.debug(f"Starting Phase 1 on host {host_id}") - - # Calculate block assignments - block_start = host_id * self.config.block_size - block_end = block_start + self.config.block_size - - # Get local block - local_block = input_ids[:, block_start:block_end].to(device) - - # Get anchor block (first block) - anchor_block = input_ids[:, : self.config.anchor_size].to( - device - ) - - # Compute KV pairs for local block - local_hidden = self.layer_norm(local_block) - local_key = self._split_heads( - self.key(local_hidden), self.config.num_attention_heads - ) - local_value = self._split_heads( - self.value(local_hidden), self.config.num_attention_heads - ) - - # Store in KV cache - self.kv_cache[host_id] = { - "key": local_key, - "value": local_value, - "anchor_key": ( - None - if host_id == 0 - else self._split_heads( - self.key(self.layer_norm(anchor_block)), - self.config.num_attention_heads, - ) - ), - } - - logger.debug( - f"Phase 1 complete on host {host_id}. KV cache shapes - " - f"key: {local_key.shape}, value: {local_value.shape}" - ) - - def phase2_query_encoding( - self, - query_input: torch.Tensor, - host_id: int, - is_query_host: bool, - device: Union[str, torch.device] = "cuda", - ) -> Optional[torch.Tensor]: - """ - Phase 2: Query Encoding and Output Generation - - Args: - query_input: Query tensor of shape (batch_size, seq_len, hidden_size) - host_id: ID of the current host - is_query_host: Whether this host is the query host - device: Device to run computations on - - Returns: - Output tensor if this is the query host, None otherwise - """ - logger.debug(f"Starting Phase 2 on host {host_id}") - - # Transform query - query_hidden = self.layer_norm(query_input) - query = self._split_heads( - self.query(query_hidden), self.config.num_attention_heads - ) - - # Compute local attention scores - local_context, local_probs = self._compute_attention_scores( - query, - self.kv_cache[host_id]["key"], - self.kv_cache[host_id]["value"], - ) - - if not is_query_host: - # Non-query hosts send their local attention statistics - dist.send(local_probs, dst=self.config.num_hosts - 1) - return None - - # Query host aggregates attention from all hosts - all_attention_probs = [local_probs] - for src_rank in range(self.config.num_hosts - 1): - probs = torch.empty_like(local_probs) - dist.recv(probs, src=src_rank) - all_attention_probs.append(probs) - - # Compute global attention - torch.mean(torch.stack(all_attention_probs), dim=0) - - # Final output computation - output = self._merge_heads(local_context) - output = self.dropout(output) - - logger.debug( - f"Phase 2 complete on host {host_id}. Output shape: {output.shape}" - ) - - return output - - def forward( - self, - input_ids: torch.Tensor, - query_input: torch.Tensor, - host_id: int, - is_query_host: bool, - device: Union[str, torch.device] = "cuda", - ) -> Optional[torch.Tensor]: - """ - Forward pass of the StarAttention module. - - Args: - input_ids: Input tensor of shape (batch_size, seq_len) - query_input: Query tensor of shape (batch_size, seq_len, hidden_size) - host_id: ID of the current host - is_query_host: Whether this host is the query host - device: Device to run computations on - - Returns: - Output tensor if this is the query host, None otherwise - """ - # Phase 1: Local Context Encoding - self.phase1_local_context_encoding(input_ids, host_id, device) - - # Phase 2: Query Encoding and Output Generation - return self.phase2_query_encoding( - query_input, host_id, is_query_host, device - ) - - -# Example forward pass -config = StarAttentionConfig( - hidden_size=768, - num_attention_heads=12, - num_hosts=3, - block_size=512, - anchor_size=128, -) - -# Initialize model -model = StarAttention(config) - -# Example input tensors -batch_size = 4 -seq_len = 512 -input_ids = torch.randint( - 0, 1000, (batch_size, seq_len) -) # Random input IDs -query_input = torch.randn( - batch_size, seq_len, config.hidden_size -) # Random query input - -# Example forward pass for query host (host_id = 2) -output = model( - input_ids=input_ids, - query_input=query_input, - host_id=2, - is_query_host=True, - device="cpu", -) - -print(output)