diff --git a/.github/actions/init_environment/action.yml b/.github/action.yml
similarity index 100%
rename from .github/actions/init_environment/action.yml
rename to .github/action.yml
diff --git a/.gitignore b/.gitignore
index 93f8e5c0..ac6be257 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,7 @@ venv
swarms/agents/.DS_Store
_build
+conversation.txt
stderr_log.txt
.vscode
diff --git a/README.md b/README.md
index 4bf0fc06..bfb77944 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Run example in Collab:
-### `Agent` Example
+### `Agent`
- Reliable Structure that provides LLMs autonomy
- Extremely customizable with stopping conditions, interactivity, dynamic temperature, loop intervals, and much more
- Enterprise Grade + Production Grade: `Agent` is designed and optimized for automating real-world tasks at scale!
@@ -127,15 +127,69 @@ for task in workflow.tasks:
```
-## `Multi Modal Autonomous Agents`
-- Run the agent with multiple modalities useful for various real-world tasks in manufacturing, logistics, and health.
+
+
+### `ModelParallelizer`
+- Concurrent Execution of Multiple Models: `ModelParallelizer` runs multiple models on the same task concurrently and collects their outputs, making it easy to compare the performance and results of different models and decide which one fits your task.
+
+- Plug-and-Play Integration: The structure provides seamless integration with various models, including OpenAIChat, Anthropic, Mixtral, and Gemini. You can plug in any of these models and start using them without extensive modifications or setup.
+
```python
-# Description: This is an example of how to use the Agent class to run a multi-modal workflow
import os
+
from dotenv import load_dotenv
-from swarms.models.gpt4_vision_api import GPT4VisionAPI
-from swarms.structs import Agent
+
+from swarms.models import Anthropic, Gemini, Mixtral, OpenAIChat
+from swarms.swarms import ModelParallelizer
+
+load_dotenv()
+
+# API Keys
+anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+gemini_api_key = os.getenv("GEMINI_API_KEY")
+
+# Initialize the models
+llm = OpenAIChat(openai_api_key=openai_api_key)
+anthropic = Anthropic(anthropic_api_key=anthropic_api_key)
+mixtral = Mixtral()
+gemini = Gemini(gemini_api_key=gemini_api_key)
+
+# Initialize the parallelizer
+llms = [llm, anthropic, mixtral, gemini]
+parallelizer = ModelParallelizer(llms)
+
+# Set the task
+task = "Generate a 10,000 word blog on health and wellness."
+
+# Run the task
+out = parallelizer.run(task)
+
+# Print the responses one by one
+for i, response in enumerate(out):
+    print(f"Response from LLM {i}: {response}")
+```
+
+
+### Simple Conversational Agent
+- Plug-and-play conversational agent with `GPT4`, `Mixtral`, or any of our models
+- Reliable conversational structure that holds messages together, with dynamic handling of long-context conversations and interactions via auto chunking
+- Reliable: this simple structure consistently returns the responses you expect.
+
+```python
+import os
+
+from dotenv import load_dotenv
+
+from swarms import (
+ OpenAIChat,
+ Conversation,
+)
+
# Load the environment variables
load_dotenv()
@@ -144,65 +198,161 @@ load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
# Initialize the language model
-llm = GPT4VisionAPI(
- openai_api_key=api_key,
- max_tokens=500,
-)
+llm = OpenAIChat(openai_api_key=api_key, model_name="gpt-4")
+
+# Run the language model in a loop
+def interactive_conversation(llm):
+    conv = Conversation(time_enabled=True)
+ while True:
+        user_input = input("User: ")
+        if user_input.lower() == "quit":
+            break
+        conv.add("user", user_input)
+ task = (
+ conv.return_history_as_string()
+ ) # Get the conversation history
+ out = llm(task)
+ conv.add("assistant", out)
+ print(
+ f"Assistant: {out}",
+ )
+ conv.display_conversation()
+ conv.export_conversation("conversation.txt")
+
+
+# Replace with your LLM instance
+interactive_conversation(llm)
-# Initialize the task
-task = (
- "Analyze this image of an assembly line and identify any issues such as"
- " misaligned parts, defects, or deviations from the standard assembly"
- " process. IF there is anything unsafe in the image, explain why it is"
- " unsafe and how it could be improved."
+```
+
+
+### `SwarmNetwork`
+- Efficient Task Management: SwarmNetwork's intelligent agent pool and task queue management system ensures tasks are distributed evenly across agents. This leads to efficient use of resources and faster task completion.
+
+- Scalability: SwarmNetwork can dynamically scale the number of agents based on the number of pending tasks. This means it can handle an increase in workload by adding more agents, and conserve resources when the workload is low by reducing the number of agents.
+
+- Versatile Deployment Options: With SwarmNetwork, each agent can be run on its own thread, process, container, machine, or even cluster. This provides a high degree of flexibility and allows for deployment that best suits the user's needs and infrastructure.
+
+```python
+import os
+
+from dotenv import load_dotenv
+
+# Import the OpenAIChat model and the Agent struct
+from swarms import OpenAIChat, Agent, SwarmNetwork
+
+# Load the environment variables
+load_dotenv()
+
+# Get the API key from the environment
+api_key = os.environ.get("OPENAI_API_KEY")
+
+# Initialize the language model
+llm = OpenAIChat(
+ temperature=0.5,
+ openai_api_key=api_key,
)
-img = "assembly_line.jpg"
## Initialize the workflow
-agent = Agent(
- llm=llm,
- max_loops="auto",
- autosave=True,
- dashboard=True,
- multi_modal=True
+agent = Agent(llm=llm, max_loops=1, agent_name="Social Media Manager")
+agent2 = Agent(llm=llm, max_loops=1, agent_name="Product Manager")
+agent3 = Agent(llm=llm, max_loops=1, agent_name="SEO Manager")
+
+
+# Load the swarmnet with the agents
+swarmnet = SwarmNetwork(
+ agents=[agent, agent2, agent3],
)
+# List the agents in the swarm network
+out = swarmnet.list_agents()
+print(out)
+
# Run the workflow on a task
-agent.run(task=task, img=img)
+out = swarmnet.run_single_agent(
+ agent2.id, "Generate a 10,000 word blog on health and wellness."
+)
+print(out)
+# Run all the agents in the swarm network on a task
+out = swarmnet.run_many_agents(
+ "Generate a 10,000 word blog on health and wellness."
+)
+print(out)
+
```
-### `OmniModalAgent`
-- An agent that can understand any modality and conditionally generate any modality.
+### `Task`
+Task Execution: The `Task` structure allows an assigned agent to execute a task via the `run` method. It's like Zapier for LLMs.
+
+- Task Description: Each Task can have a description, providing a human-readable explanation of what the task is intended to do.
+- Task Scheduling: Tasks can be scheduled for execution at a specific time using the `schedule_time` attribute (see the sketch after the example below).
+- Task Triggers: The set_trigger method allows for the setting of a trigger function that is executed before the task.
+- Task Actions: The set_action method allows for the setting of an action function that is executed after the task.
+- Task Conditions: The set_condition method allows for the setting of a condition function. The task will only be executed if this function returns True.
+- Task Dependencies: The add_dependency method allows for the addition of dependencies to the task. The task will only be executed if all its dependencies have been completed.
+- Task Priority: The set_priority method allows for the setting of the task's priority. Tasks with higher priority will be executed before tasks with lower priority.
+- Task History: The history attribute is a list that keeps track of all the results of the task execution. This can be useful for debugging and for tasks that need to be executed multiple times.
```python
-from swarms.agents.omni_modal_agent import OmniModalAgent, OpenAIChat
+from swarms.structs import Task, Agent
from swarms.models import OpenAIChat
from dotenv import load_dotenv
import os
+
# Load the environment variables
load_dotenv()
-# Get the API key from the environment
-api_key = os.environ.get("OPENAI_API_KEY")
-# Initialize the language model
-llm = OpenAIChat(
- temperature=0.5,
- model_name="gpt-4",
- openai_api_key=api_key,
+# Define a function to be used as the action
+def my_action():
+ print("Action executed")
+
+
+# Define a function to be used as the condition
+def my_condition():
+ print("Condition checked")
+ return True
+
+
+# Create an agent
+agent = Agent(
+ llm=OpenAIChat(openai_api_key=os.environ["OPENAI_API_KEY"]),
+ max_loops=1,
+ dashboard=False,
)
+# Create a task
+task = Task(description="What's the weather in Miami", agent=agent)
+
+# Set the action and condition
+task.set_action(my_action)
+task.set_condition(my_condition)
+
+# Execute the task
+print("Executing task...")
+task.run()
+
+# Check if the task is completed
+if task.is_completed():
+ print("Task completed")
+else:
+ print("Task not completed")
+
+# Output the result of the task
+print(f"Task result: {task.result}")
+
-agent = OmniModalAgent(llm)
-agent.run("Generate a video of a swarm of fish and then make an image out of the video")
```
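+
+The scheduling, priority, and dependency helpers listed above can be combined as sketched below. This is a minimal sketch based only on the method names documented here; the exact signatures (for example, what `schedule_time` and `add_dependency` accept) are assumptions and may differ in your installed version of swarms.
+
+```python
+import datetime
+
+from swarms.structs import Task
+
+# Reuses `agent` and `task` from the example above
+followup = Task(
+    description="Summarize the weather report generated by the previous task",
+    agent=agent,
+)
+
+# Hypothetical usage of the documented scheduling / priority / dependency APIs
+followup.schedule_time = datetime.datetime.now() + datetime.timedelta(hours=1)
+followup.set_priority(1)
+followup.add_dependency(task)  # only run after `task` has completed
+
+followup.run()
+print(followup.history)  # all results recorded for this task
+```
+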
---
+
+## Real-World Deployment
+
### Multi-Agent Swarm for Logistics
- Swarms is a framework designed for real-world deployment; here is a demo presenting a fully ready-to-use swarm for a vast array of logistics tasks.
- Swarms is designed to be modular and reliable for real-world deployments.
@@ -312,8 +462,60 @@ efficiency_analysis = efficiency_agent.run(
factory_image,
)
```
+---
+
+
+## `Multi Modal Autonomous Agents`
+- Run the agent with multiple modalities, which is useful for various real-world tasks in manufacturing, logistics, and healthcare.
+
+```python
+# Description: This is an example of how to use the Agent class to run a multi-modal workflow
+import os
+from dotenv import load_dotenv
+from swarms.models.gpt4_vision_api import GPT4VisionAPI
+from swarms.structs import Agent
+
+# Load the environment variables
+load_dotenv()
+
+# Get the API key from the environment
+api_key = os.environ.get("OPENAI_API_KEY")
+
+# Initialize the language model
+llm = GPT4VisionAPI(
+ openai_api_key=api_key,
+ max_tokens=500,
+)
+
+# Initialize the task
+task = (
+ "Analyze this image of an assembly line and identify any issues such as"
+ " misaligned parts, defects, or deviations from the standard assembly"
+ " process. IF there is anything unsafe in the image, explain why it is"
+ " unsafe and how it could be improved."
+)
+img = "assembly_line.jpg"
+
+## Initialize the workflow
+agent = Agent(
+ llm=llm,
+ max_loops="auto",
+ autosave=True,
+ dashboard=True,
+ multi_modal=True
+)
+
+# Run the workflow on a task
+agent.run(task=task, img=img)
+
+
+```
+
+---
+
+## Multi-Modal Model APIs
-### Gemini
+### `Gemini`
- Deploy Google's Gemini with our visual chain-of-thought prompt, which enables more reliable responses
```python
import os
@@ -386,7 +588,7 @@ generated_text = inference(prompt_text)
print(generated_text)
```
-### Mixtral
+### `Mixtral`
- Utilize Mixtral through a very simple API
- Utilize 4-bit quantization for increased speed and lower memory usage
- Use Flash Attention 2.0 for increased speed and lower memory usage
@@ -403,6 +605,63 @@ generated_text = mixtral.run("Generate a creative story.")
print(generated_text)
```
+
+### `Dalle3`
+```python
+from swarms import Dalle3
+
+# Create an instance of the Dalle3 class with high quality
+dalle3 = Dalle3(quality="high")
+
+# Define a text prompt
+task = "A high-quality image of a sunset"
+
+# Generate a high-quality image from the text prompt
+image_url = dalle3(task)
+
+# Print the generated image URL
+print(image_url)
+```
+
+
+### `GPT4Vision`
+```python
+from swarms.models import GPT4VisionAPI
+
+# Initialize with default API key and custom max_tokens
+api = GPT4VisionAPI(max_tokens=1000)
+
+# Define the task and image URL
+task = "Describe the scene in the image."
+img = "https://i.imgur.com/4P4ZRxU.jpeg"
+
+# Run the GPT-4 Vision model
+response = api.run(task, img)
+
+# Print the model's response
+print(response)
+```
+
+
+### Text to Video with `ZeroscopeTTV`
+
+```python
+# Import the model
+from swarms import ZeroscopeTTV
+
+# Initialize the model
+zeroscope = ZeroscopeTTV()
+
+# Specify the task
+task = "A person is walking on the street."
+
+# Generate the video!
+video_path = zeroscope(task)
+print(video_path)
+
+```
+
+
---
# Features
@@ -477,7 +736,7 @@ Swarms framework is not just a tool but a robust, scalable, and secure partner i
## Documentation
-- For documentation, go here, [swarms.apac.ai](https://swarms.apac.ai)
+- Our documentation is located at: [swarms.apac.ai](https://swarms.apac.ai)
## Contributions:
@@ -498,7 +757,7 @@ To see how to contribute, visit [Contribution guidelines](https://github.com/kye
## Discovery Call
-Book a discovery call with the Swarms team to learn how to optimize and scale your swarm! [Click here to book a time that works for you!](https://calendly.com/swarm-corp/30min?month=2023-11)
+Book a discovery call to learn how Swarms can lower your operating costs by 40% with swarms of autonomous agents at lightspeed. [Click here to book a time that works for you!](https://calendly.com/swarm-corp/30min?month=2023-11)
# License
Apache License
diff --git a/docs/swarms/models/zeroscope.md b/docs/swarms/models/zeroscope.md
new file mode 100644
index 00000000..4e634a6a
--- /dev/null
+++ b/docs/swarms/models/zeroscope.md
@@ -0,0 +1,105 @@
+# Module Name: ZeroscopeTTV
+
+## Introduction
+The ZeroscopeTTV module is a versatile zero-shot video generation model designed to create videos based on textual descriptions. This comprehensive documentation will provide you with an in-depth understanding of the ZeroscopeTTV module, its architecture, purpose, arguments, and detailed usage examples.
+
+## Purpose
+The ZeroscopeTTV module serves as a powerful tool for generating videos from text descriptions. Whether you need to create video content for various applications, visualize textual data, or explore the capabilities of ZeroscopeTTV, this module offers a flexible and efficient solution. With its easy-to-use interface, you can quickly generate videos based on your textual input.
+
+## Architecture
+The ZeroscopeTTV module is built on top of the Diffusers library, leveraging the power of diffusion models for video generation. It allows you to specify various parameters such as model name, data type, chunk size, dimensions, and more to customize the video generation process. The model performs multiple inference steps and utilizes a diffusion pipeline to generate high-quality videos.
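+
+Under the hood, this corresponds roughly to running a Diffusers text-to-video pipeline, as sketched below. This is an illustrative approximation based on the defaults documented on this page, not the module's actual source; the handling of the pipeline output (`result.frames` vs. `result.frames[0]`) depends on your Diffusers version.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+
+# Load the default Zeroscope checkpoint in half precision
+pipe = DiffusionPipeline.from_pretrained(
+    "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16
+).to("cuda")
+
+# Run the diffusion process with the defaults documented below
+result = pipe(
+    "A bird flying in the sky.",
+    num_inference_steps=40,
+    height=320,
+    width=576,
+    num_frames=36,
+)
+
+# Export the generated frames to an .mp4 file (newer Diffusers versions may require result.frames[0])
+video_path = export_to_video(result.frames)
+print(video_path)
+```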
+
+## Class Definition
+### `ZeroscopeTTV(model_name: str = "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16, chunk_size: int = 1, dim: int = 1, num_inference_steps: int = 40, height: int = 320, width: int = 576, num_frames: int = 36)`
+
+#### Parameters
+- `model_name` (str, optional): The name of the pre-trained model to use. Default is "cerspense/zeroscope_v2_576w".
+- `torch_dtype` (torch.dtype, optional): The torch data type to use for computations. Default is torch.float16.
+- `chunk_size` (int, optional): The size of chunks for forward chunking. Default is 1.
+- `dim` (int, optional): The dimension along which the input is split for forward chunking. Default is 1.
+- `num_inference_steps` (int, optional): The number of inference steps to perform. Default is 40.
+- `height` (int, optional): The height of the video frames. Default is 320.
+- `width` (int, optional): The width of the video frames. Default is 576.
+- `num_frames` (int, optional): The number of frames in the video. Default is 36.
+
+## Functionality and Usage
+The ZeroscopeTTV module offers a straightforward interface for video generation. It accepts a textual task or description as input and returns the path to the generated video.
+
+### `run(task: str = None, *args, **kwargs) -> str`
+
+#### Parameters
+- `task` (str, optional): The input task or description for video generation.
+
+#### Returns
+- `str`: The path to the generated video.
+
+## Usage Examples
+### Example 1: Basic Usage
+
+```python
+from swarms.models import ZeroscopeTTV
+
+# Initialize the ZeroscopeTTV model
+zeroscope = ZeroscopeTTV()
+
+# Generate a video based on a textual description
+task = "A bird flying in the sky."
+video_path = zeroscope.run(task)
+print(f"Generated video path: {video_path}")
+```
+
+### Example 2: Custom Model and Parameters
+
+You can specify a custom pre-trained model and adjust various parameters for video generation.
+
+```python
+import torch
+
+from swarms.models import ZeroscopeTTV
+
+custom_model_name = "your_custom_model_path"
+custom_dtype = torch.float32
+custom_chunk_size = 2
+custom_dim = 2
+custom_num_inference_steps = 50
+custom_height = 480
+custom_width = 720
+custom_num_frames = 48
+
+custom_zeroscope = ZeroscopeTTV(
+ model_name=custom_model_name,
+ torch_dtype=custom_dtype,
+ chunk_size=custom_chunk_size,
+ dim=custom_dim,
+ num_inference_steps=custom_num_inference_steps,
+ height=custom_height,
+ width=custom_width,
+ num_frames=custom_num_frames,
+)
+
+task = "A car driving on the road."
+video_path = custom_zeroscope.run(task)
+print(f"Generated video path: {video_path}")
+```
+
+### Example 3: Exporting Video Frames
+
+You can also export individual video frames if needed.
+
+```python
+from swarms.models import export_to_video
+
+# Generate video frames
+video_frames = zeroscope.run("A boat sailing on the water.")
+
+# Export video frames to a video file
+video_path = export_to_video(video_frames)
+print(f"Generated video path: {video_path}")
+```
+
+## Additional Information and Tips
+- Ensure that the input textual task or description is clear and descriptive to achieve the desired video output.
+- Experiment with different parameter settings to control video resolution, frame count, and inference steps.
+- Use the `export_to_video` function to export individual video frames as needed.
+- Monitor the progress and output paths to access the generated videos.
+
+## Conclusion
+The ZeroscopeTTV module is a powerful solution for zero-shot video generation based on textual descriptions. Whether you are creating videos for storytelling, data visualization, or other applications, ZeroscopeTTV offers a versatile and efficient way to bring your text to life. With a flexible interface and customizable parameters, it empowers you to generate high-quality videos with ease.
+
+If you encounter any issues or have questions about using ZeroscopeTTV, please refer to the Diffusers library documentation or reach out to their support team for further assistance. Enjoy creating videos with ZeroscopeTTV!
\ No newline at end of file
diff --git a/docs/swarms/swarms/godmode.md b/docs/swarms/swarms/godmode.md
index a0965c94..6655c954 100644
--- a/docs/swarms/swarms/godmode.md
+++ b/docs/swarms/swarms/godmode.md
@@ -1,4 +1,4 @@
-# `GodMode` Documentation
+# `ModelParallelizer` Documentation
## Table of Contents
1. [Understanding the Purpose](#understanding-the-purpose)
@@ -11,19 +11,19 @@
## 1. Understanding the Purpose
-To create comprehensive documentation for the `GodMode` class, let's begin by understanding its purpose and functionality.
+To create comprehensive documentation for the `ModelParallelizer` class, let's begin by understanding its purpose and functionality.
### Purpose and Functionality
-`GodMode` is a class designed to facilitate the orchestration of multiple Language Model Models (LLMs) to perform various tasks simultaneously. It serves as a powerful tool for managing, distributing, and collecting responses from these models.
+`ModelParallelizer` is a class designed to facilitate the orchestration of multiple Large Language Models (LLMs) to perform various tasks simultaneously. It serves as a powerful tool for managing, distributing, and collecting responses from these models.
Key features and functionality include:
-- **Parallel Task Execution**: `GodMode` can distribute tasks to multiple LLMs and execute them in parallel, improving efficiency and reducing response time.
+- **Parallel Task Execution**: `ModelParallelizer` can distribute tasks to multiple LLMs and execute them in parallel, improving efficiency and reducing response time.
- **Structured Response Presentation**: The class presents the responses from LLMs in a structured tabular format, making it easy for users to compare and analyze the results.
-- **Task History Tracking**: `GodMode` keeps a record of tasks that have been submitted, allowing users to review previous tasks and responses.
+- **Task History Tracking**: `ModelParallelizer` keeps a record of tasks that have been submitted, allowing users to review previous tasks and responses.
- **Asynchronous Execution**: The class provides options for asynchronous task execution, which can be particularly useful for handling a large number of tasks.
@@ -33,29 +33,29 @@ Now that we have an understanding of its purpose, let's proceed to provide a det
### Overview
-The `GodMode` class is a crucial component for managing and utilizing multiple LLMs in various natural language processing (NLP) tasks. Its architecture and functionality are designed to address the need for parallel processing and efficient response handling.
+The `ModelParallelizer` class is a crucial component for managing and utilizing multiple LLMs in various natural language processing (NLP) tasks. Its architecture and functionality are designed to address the need for parallel processing and efficient response handling.
### Importance and Relevance
-In the rapidly evolving field of NLP, it has become common to use multiple language models to achieve better results in tasks such as translation, summarization, and question answering. `GodMode` streamlines this process by allowing users to harness the capabilities of several LLMs simultaneously.
+In the rapidly evolving field of NLP, it has become common to use multiple language models to achieve better results in tasks such as translation, summarization, and question answering. `ModelParallelizer` streamlines this process by allowing users to harness the capabilities of several LLMs simultaneously.
Key points:
-- **Parallel Processing**: `GodMode` leverages multithreading to execute tasks concurrently, significantly reducing the time required for processing.
+- **Parallel Processing**: `ModelParallelizer` leverages multithreading to execute tasks concurrently, significantly reducing the time required for processing.
- **Response Visualization**: The class presents responses in a structured tabular format, enabling users to visualize and analyze the outputs from different LLMs.
-- **Task Tracking**: Developers can track the history of tasks submitted to `GodMode`, making it easier to manage and monitor ongoing work.
+- **Task Tracking**: Developers can track the history of tasks submitted to `ModelParallelizer`, making it easier to manage and monitor ongoing work.
### Architecture and How It Works
-The architecture and working of `GodMode` can be summarized in four steps:
+The architecture and working of `ModelParallelizer` can be summarized in four steps:
-1. **Task Reception**: `GodMode` receives a task from the user.
+1. **Task Reception**: `ModelParallelizer` receives a task from the user.
2. **Task Distribution**: The class distributes the task to all registered LLMs.
-3. **Response Collection**: `GodMode` collects the responses generated by the LLMs.
+3. **Response Collection**: `ModelParallelizer` collects the responses generated by the LLMs.
4. **Response Presentation**: Finally, the class presents the responses from all LLMs in a structured tabular format, making it easy for users to compare and analyze the results.
@@ -65,15 +65,15 @@ Now that we have an overview, let's proceed with a detailed class definition.
### Class Attributes
-- `llms`: A list of LLMs (Language Model Models) that `GodMode` manages.
+- `llms`: A list of LLMs (Large Language Models) that `ModelParallelizer` manages.
- `last_responses`: Stores the responses from the most recent task.
-- `task_history`: Keeps a record of all tasks submitted to `GodMode`.
+- `task_history`: Keeps a record of all tasks submitted to `ModelParallelizer`.
### Methods
-The `GodMode` class defines various methods to facilitate task distribution, execution, and response presentation. Let's examine some of the key methods:
+The `ModelParallelizer` class defines various methods to facilitate task distribution, execution, and response presentation. Let's examine some of the key methods:
- `run(task)`: Distributes a task to all LLMs, collects responses, and returns them.
@@ -87,23 +87,23 @@ The `GodMode` class defines various methods to facilitate task distribution, exe
- `save_responses_to_file(filename)`: Saves responses to a file for future reference.
-- `load_llms_from_file(filename)`: Loads LLMs from a file, making it easy to configure `GodMode` for different tasks.
+- `load_llms_from_file(filename)`: Loads LLMs from a file, making it easy to configure `ModelParallelizer` for different tasks.
- `get_task_history()`: Retrieves the task history, allowing users to review previous tasks.
- `summary()`: Provides a summary of task history and the last responses, aiding in post-processing and analysis.
-Now that we have covered the class definition, let's delve into the functionality and usage of `GodMode`.
+Now that we have covered the class definition, let's delve into the functionality and usage of `ModelParallelizer`.
## 4. Functionality and Usage
### Distributing a Task and Collecting Responses
-One of the primary use cases of `GodMode` is to distribute a task to all registered LLMs and collect their responses. This can be achieved using the `run(task)` method. Below is an example:
+One of the primary use cases of `ModelParallelizer` is to distribute a task to all registered LLMs and collect their responses. This can be achieved using the `run(task)` method. Below is an example:
```python
-god_mode = GodMode(llms)
-responses = god_mode.run("Translate the following English text to French: 'Hello, how are you?'")
+parallelizer = ModelParallelizer(llms)
+responses = parallelizer.run("Translate the following English text to French: 'Hello, how are you?'")
```
### Printing Responses
@@ -111,7 +111,7 @@ responses = god_mode.run("Translate the following English text to French: 'Hello
To present the responses from all LLMs in a structured tabular format, use the `print_responses(task)` method. Example:
```python
-god_mode.print_responses("Summarize the main points of 'War and Peace.'")
+parallelizer.print_responses("Summarize the main points of 'War and Peace.'")
```
### Saving Responses to a File
@@ -119,15 +119,15 @@ god_mode.print_responses("Summarize the main points of 'War and Peace.'")
Users can save the responses to a file using the `save_responses_to_file(filename)` method. This is useful for archiving and reviewing responses later. Example:
```python
-god_mode.save_responses_to_file("responses.txt")
+parallelizer.save_responses_to_file("responses.txt")
```
### Task History
-The `GodMode` class keeps track of the task history. Developers can access the task history using the `get_task_history()` method. Example:
+The `ModelParallelizer` class keeps track of the task history. Developers can access the task history using the `get_task_history()` method. Example:
```python
-task_history = god_mode.get_task_history()
+task_history = parallelizer.get_task_history()
for i, task in enumerate(task_history):
print(f"Task {i + 1}: {task}")
```
@@ -136,7 +136,7 @@ for i, task in enumerate(task_history):
### Parallel Execution
-`GodMode` employs multithreading to execute tasks concurrently. This parallel processing capability significantly improves the efficiency of handling multiple tasks simultaneously.
+`ModelParallelizer` employs multithreading to execute tasks concurrently. This parallel processing capability significantly improves the efficiency of handling multiple tasks simultaneously.
### Response Visualization
@@ -144,13 +144,13 @@ The structured tabular format used for presenting responses simplifies the compa
## 6. Examples
-Let's explore additional usage examples to illustrate the versatility of `GodMode` in handling various NLP tasks.
+Let's explore additional usage examples to illustrate the versatility of `ModelParallelizer` in handling various NLP tasks.
### Example 1: Sentiment Analysis
```python
from swarms.models import OpenAIChat
-from swarms.swarms import GodMode
+from swarms.swarms import ModelParallelizer
from swarms.workers.worker import Worker
# Create an instance of an LLM for sentiment analysis
@@ -184,15 +184,15 @@ worker3 = Worker(
temperature=0.5,
)
-# Register the worker agents with GodMode
+# Register the worker agents with ModelParallelizer
agents = [worker1, worker2, worker3]
-god_mode = GodMode(agents)
+parallelizer = ModelParallelizer(agents)
# Task for sentiment analysis
task = "Please analyze the sentiment of the following sentence: 'This movie is amazing!'"
# Print responses from all agents
-god_mode.print_responses(task)
+parallelizer.print_responses(task)
```
### Example 2: Translation
@@ -200,22 +200,22 @@ god_mode.print_responses(task)
```python
from swarms.models import OpenAIChat
-from swarms.swarms import GodMode
+from swarms.swarms import ModelParallelizer
# Define LLMs for translation tasks
translator1 = OpenAIChat(model_name="translator-en-fr", openai_api_key="api-key", temperature=0.7)
translator2 = OpenAIChat(model_name="translator-en-es", openai_api_key="api-key", temperature=0.7)
translator3 = OpenAIChat(model_name="translator-en-de", openai_api_key="api-key", temperature=0.7)
-# Register translation agents with GodMode
+# Register translation agents with ModelParallelizer
translators = [translator1, translator2, translator3]
-god_mode = GodMode(translators)
+parallelizer = ModelParallelizer(translators)
# Task for translation
task = "Translate the following English text to French: 'Hello, how are you?'"
# Print translated responses from all agents
-god_mode.print_responses(task)
+parallelizer.print_responses(task)
```
### Example 3: Summarization
@@ -223,7 +223,7 @@ god_mode.print_responses(task)
```python
from swarms.models import OpenAIChat
-from swarms.swarms import GodMode
+from swarms.swarms import ModelParallelizer
# Define LLMs for summarization tasks
@@ -231,19 +231,19 @@ summarizer1 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", t
summarizer2 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", temperature=0.6)
summarizer3 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", temperature=0.6)
-# Register summarization agents with GodMode
+# Register summarization agents with ModelParallelizer
summarizers = [summarizer1, summarizer2, summarizer3]
-god_mode = GodMode(summarizers)
+parallelizer = ModelParallelizer(summarizers)
# Task for summarization
task = "Summarize the main points of the article titled 'Climate Change and Its Impact on the Environment.'"
# Print summarized responses from all agents
-god_mode.print_responses(task)
+parallelizer.print_responses(task)
```
## 7. Conclusion
-In conclusion, the `GodMode` class is a powerful tool for managing and orchestrating multiple Language Model Models in natural language processing tasks. Its ability to distribute tasks, collect responses, and present them in a structured format makes it invaluable for streamlining NLP workflows. By following the provided documentation, users can harness the full potential of `GodMode` to enhance their natural language processing projects.
+In conclusion, the `ModelParallelizer` class is a powerful tool for managing and orchestrating multiple Large Language Models in natural language processing tasks. Its ability to distribute tasks, collect responses, and present them in a structured format makes it invaluable for streamlining NLP workflows. By following the provided documentation, users can harness the full potential of `ModelParallelizer` to enhance their natural language processing projects.
For further information on specific LLMs or advanced usage, refer to the documentation of the respective models and their APIs. Additionally, external resources on parallel execution and response visualization can provide deeper insights into these topics.
\ No newline at end of file
diff --git a/docs/swarms/utils/check_device.md b/docs/swarms/utils/check_device.md
new file mode 100644
index 00000000..bdb8c780
--- /dev/null
+++ b/docs/swarms/utils/check_device.md
@@ -0,0 +1,86 @@
+# check_device
+
+# Module/Function Name: check_device
+
+`check_device` is a utility function in `swarms.utils`, built on PyTorch, that identifies and returns the appropriate device(s) for CUDA processing. If CUDA is not available, the CPU device is returned. If CUDA is available, the function returns a list of all available GPU devices.
+
+The function examines CUDA availability, checks for multiple GPUs, and inspects additional properties of each device.
+
+## Function Signature and Arguments
+
+**Signature:**
+```python
+def check_device(
+ log_level: Any = logging.INFO,
+ memory_threshold: float = 0.8,
+ capability_threshold: float = 3.5,
+ return_type: str = "list",
+) -> Union[torch.device, List[torch.device]]
+```
+
+| Parameter | Data Type | Default Value | Description |
+| ------------- | ------------- | ------------- | ------------- |
+| `log_level` | Any | logging.INFO | The log level. |
+| `memory_threshold` | float | 0.8 | The maximum fraction of GPU memory that may already be in use for a device to be considered. |
+| `capability_threshold` | float | 3.5 | The minimum CUDA compute capability a GPU must have to be considered. |
+| `return_type` | str | "list" | Whether to return a list of devices or a single device. |
+
+This function does not take any mandatory arguments; it supports the optional arguments `log_level`, `memory_threshold`, `capability_threshold`, and `return_type`.
+
+**Returns:**
+
+- A single `torch.device` if only one device is appropriate, a list of `torch.device` objects if multiple CUDA devices are available, or the CPU device if CUDA is not available.
+
+## Usage and Examples
+
+### Example 1: Basic Usage
+
+```python
+import torch
+import logging
+from swarms.utils import check_device
+
+# Basic usage
+device = check_device(
+ log_level=logging.INFO,
+ memory_threshold=0.8,
+ capability_threshold=3.5,
+ return_type="list"
+)
+```
+
+### Example 2: Using CPU when CUDA is not available
+
+```python
+import torch
+import logging
+from swarms.utils import check_device
+
+# When CUDA is not available
+device = check_device()
+print(device) # If CUDA is not available it should return torch.device('cpu')
+```
+
+### Example 3: Multiple GPU Available
+
+```python
+import torch
+import logging
+from swarms.utils import check_device
+
+# When multiple GPUs are available
+device = check_device()
+print(device) # Should return a list of available GPU devices
+```
+
+## Tips and Additional Information
+
+- This function is useful when a user wants to exploit CUDA capabilities for faster computation but is unsure of the available devices. It abstracts away the necessary checks and provides the user with a list of suitable CUDA devices.
+- The `memory_threshold` and `capability_threshold` are used to filter the GPU devices: GPUs whose memory usage is above `memory_threshold` or whose compute capability is below `capability_threshold` are not considered (see the sketch below).
+- The CPU device has no memory or capability values, so when CUDA is unavailable it is returned as the default without any comparison.
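+
+For illustration only, the selection logic described in this document could be sketched roughly as follows. This is an assumption-based sketch, not the actual `swarms.utils` implementation; details such as how memory usage and compute capability are compared may differ.
+
+```python
+import torch
+
+def select_cuda_devices(memory_threshold: float = 0.8, capability_threshold: float = 3.5):
+    """Illustrative sketch of the filtering behavior described above."""
+    if not torch.cuda.is_available():
+        return torch.device("cpu")
+
+    devices = []
+    for i in range(torch.cuda.device_count()):
+        props = torch.cuda.get_device_properties(i)
+        major, minor = torch.cuda.get_device_capability(i)
+        memory_used = torch.cuda.memory_allocated(i) / props.total_memory
+        # Keep GPUs below the memory threshold and at or above the capability threshold
+        if memory_used <= memory_threshold and (major + minor / 10) >= capability_threshold:
+            devices.append(torch.device(f"cuda:{i}"))
+
+    return devices or [torch.device("cpu")]
+```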
+
+## Relevant Resources
+
+- For more details about the CUDA properties functions used (`torch.cuda.get_device_capability, torch.cuda.get_device_properties`), please refer to the official PyTorch [CUDA semantics documentation](https://pytorch.org/docs/stable/notes/cuda.html).
+- For more information about Torch device objects, you can refer to the official PyTorch [device documentation](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device).
+- For a better understanding of how the `logging` module works in Python, see the official Python [logging documentation](https://docs.python.org/3/library/logging.html).
diff --git a/docs/swarms/utils/display_markdown_message.md b/docs/swarms/utils/display_markdown_message.md
new file mode 100644
index 00000000..c1e3f894
--- /dev/null
+++ b/docs/swarms/utils/display_markdown_message.md
@@ -0,0 +1,86 @@
+# display_markdown_message
+
+# Module Name: `display_markdown_message`
+
+## Introduction
+
+`display_markdown_message` is a utility function for creating visually pleasing markdown messages within Python scripts. It automatically manages multiline strings with heavy indentation and makes single-line messages with ">" tags easy to read, providing a convenient and elegant logging or messaging capability.
+
+## Function Definition and Arguments
+
+Function Definition:
+```python
+def display_markdown_message(message: str, color: str = "cyan"):
+```
+This function accepts two parameters:
+
+|Parameter |Type |Default Value |Description |
+|--- |--- |--- |--- |
+|message |str |(required) |The message to be displayed. It should be a string and may contain markdown syntax.|
+|color |str |"cyan" |The color of the message. Accepts any valid color name.|
+
+## Functionality and Usage
+
+This utility function is used to display a markdown formatted message on the console. It accepts a message as a string and an optional color for the message. The function is ideal for generating stylized print outputs such as headers, status updates or pretty notifications.
+
+By default, any text within the string which is enclosed within `>` tags or `---` is treated specially:
+
+- Lines encased in `>` tags are rendered as a blockquote in markdown.
+- Lines consisting of `---` are rendered as horizontal rules.
+
+The function automatically strips off leading and trailing whitespaces from any line within the message, maintaining aesthetic consistency in your console output.
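+
+For intuition, a minimal sketch of this line-by-line rendering is shown below. It assumes the `rich` library as the rendering backend, which is an assumption on our part; the real implementation may use a different library.
+
+```python
+from rich import print as rich_print
+from rich.markdown import Markdown
+from rich.rule import Rule
+
+def display_markdown_message_sketch(message: str, color: str = "cyan"):
+    """Illustrative sketch of the line-by-line rendering described above."""
+    for line in message.split("\n"):
+        line = line.strip()  # strip leading/trailing whitespace
+        if line == "":
+            print()
+        elif line == "---":
+            rich_print(Rule(style=color))  # horizontal rule
+        else:
+            rich_print(Markdown(line, style=color))  # blockquotes and other markdown
+```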
+
+### Usage Examples
+
+#### Basic Example
+
+```python
+display_markdown_message("> This is an important message", color="red")
+```
+
+Output:
+```md
+> **This is an important message**
+```
+
+This example will print out the string "This is an important message" in red color, enclosed in a blockquote tag.
+
+#### Multiline Example
+
+```python
+message = """
+> Header
+
+My normal message here.
+
+---
+
+Another important information
+"""
+display_markdown_message(message, color="green")
+```
+
+Output:
+```md
+> **Header**
+
+My normal message here.
+_____
+
+Another important information
+```
+The output is green-colored, markdown-styled text with "Header" enclosed in a blockquote, followed by the phrase "My normal message here.", a horizontal rule, and finally the phrase "Another important information".
+
+## Additional Information
+
+Use newline characters `\n` to separate the lines of the message. Remember, each line of the message is stripped of leading and trailing whitespaces. If you have special markdown requirements, you may need to revise the input message string accordingly.
+
+Also, keep in mind the console or terminal's ability to display the chosen color. If a particular console does not support the chosen color, the output may fall back to the default console color.
+
+For a full list of color names supported by the `Console` module, refer to the official [Console documentation](http://console.readthedocs.io/).
+
+## References and Resources
+
+- Python Strings: https://docs.python.org/3/tutorial/introduction.html#strings
+- Python Markdown: https://pypi.org/project/markdown/
+- Console module: https://console.readthedocs.io/
diff --git a/docs/swarms/utils/extract_code_from_markdown.md b/docs/swarms/utils/extract_code_from_markdown.md
new file mode 100644
index 00000000..f6f76835
--- /dev/null
+++ b/docs/swarms/utils/extract_code_from_markdown.md
@@ -0,0 +1,114 @@
+# extract_code_from_markdown
+
+# swarms.utils Module
+
+The `swarms.utils` module provides utility functions designed to facilitate specific tasks within the main Swarm codebase. The function `extract_code_from_markdown` is a critical function within this module that we will document in this example.
+
+## Overview and Introduction
+
+Many software projects use Markdown extensively for writing documentation, tutorials, and other text documents that can be easily rendered and viewed in different formats, including HTML.
+
+The `extract_code_from_markdown` function plays a crucial role within the swarms.utils library. As developers write large volumes of Markdown, they often need to isolate code snippets from the Markdown body. These isolated snippets can be used to generate test cases, be transformed into other languages, or be analyzed for metrics.
+
+## Function Definition: `extract_code_from_markdown`
+
+```python
+import re
+
+
+def extract_code_from_markdown(markdown_content: str) -> str:
+ """
+ Extracts code blocks from a Markdown string and returns them as a single string.
+
+ Args:
+ - markdown_content (str): The Markdown content as a string.
+
+ Returns:
+ - str: A single string containing all the code blocks separated by newlines.
+ """
+ # Regular expression for fenced code blocks
+ pattern = r"```(?:\w+\n)?(.*?)```"
+ matches = re.findall(pattern, markdown_content, re.DOTALL)
+
+ # Concatenate all code blocks separated by newlines
+ return "\n".join(code.strip() for code in matches)
+```
+
+### Arguments
+
+The function `extract_code_from_markdown` takes one argument:
+
+| Argument | Description | Type | Default Value |
+|-----------------------|----------------------------------------|-------------|-------------------|
+| markdown_content | The input markdown content as a string | str | N/A |
+
+
+## Function Explanation and Usage
+
+This function uses a regular expression to find all fenced code blocks in a Markdown string. The pattern `r"```(?:\w+\n)?(.*?)```"` matches blocks that start and end with three backticks, with an optional language identifier and newline after the opening backticks (the `(?:\w+\n)?` part), lazily capturing any characters (the `.*?` part) until the closing set of triple backticks.
+
+Once we have the matches, we join all the code blocks into a single string, each block separated by a newline.
+
+The method's functionality is particularly useful when we need to extract code blocks from markdown content for secondary processing, such as syntax highlighting or execution in a different environment.
+
+### Usage Examples
+
+Below are three examples of how you might use this function:
+
+#### Example 1:
+
+Extracting code blocks from a simple markdown string.
+
+```python
+from swarms.utils import extract_code_from_markdown
+
+markdown_string = '''# Example
+This is an example of a code block:
+```python
+print("Hello World!")
+``` '''
+print(extract_code_from_markdown(markdown_string))
+```
+
+#### Example 2:
+
+Extracting code blocks from a markdown file.
+
+```python
+from swarms.utils import extract_code_from_markdown
+
+# Assume that 'example.md' contains multiple code blocks
+with open('example.md', 'r') as file:
+ markdown_content = file.read()
+print(extract_code_from_markdown(markdown_content))
+```
+
+#### Example 3:
+
+Using the function in a pipeline to extract and then analyze code blocks.
+
+```python
+from swarms.utils import extract_code_from_markdown
+
+
+def analyze_code_blocks(code: str):
+ # Add your analysis logic here
+ pass
+
+# Assume that 'example.md' contains multiple code blocks
+with open('example.md', 'r') as file:
+ markdown_content = file.read()
+code_blocks = extract_code_from_markdown(markdown_content)
+analyze_code_blocks(code_blocks)
+```
+
+## Conclusion
+
+This concludes the detailed documentation of the `extract_code_from_markdown` function from the swarms.utils module. With this documentation, you should be able to understand the function's purpose, how it works, its parameters, and see examples of how to use it effectively.
diff --git a/docs/swarms/utils/find_image_path.md b/docs/swarms/utils/find_image_path.md
new file mode 100644
index 00000000..59c9c127
--- /dev/null
+++ b/docs/swarms/utils/find_image_path.md
@@ -0,0 +1,94 @@
+# find_image_path
+
+This documentation is organized into the following sections.
+
+# Overview
+The **swarms.utils** module provides utility functions that are crucial when building swarm intelligence frameworks. These cover common operations such as file input/output handling, text parsing, and the basic mathematical computations needed when creating swarm intelligence models.
+
+The `find_image_path` function in this module extracts an image path from a given text document.
+
+# Function Detailed Explanation
+
+## Definition
+The function `find_image_path` takes a singular argument as an input:
+
+```python
+def find_image_path(text):
+ # function body
+```
+
+## Parameter
+The parameter `text` in the function is a string that represents the document or text from which the function is trying to extract all paths to the images present. The function scans the given text, looking for absolute or relative paths to image files (.png, .jpg, .jpeg) on the disk.
+
+| Parameter Name | Data Type | Default Value | Description |
+|:--------------:|:---------:|:-------------:|:--------:|
+| `text` | `str` | - | The text content to scan for image paths |
+
+## Return Value
+
+The return value of the function `find_image_path` is a string that represents the longest existing image path extracted from the input text. If no image paths exist within the text, the function returns `None`.
+
+
+| Return Value | Data Type | Description |
+|:------------:|:-----------:|:-----------:|
+| Path | `str` | Longest image path found in the text or `None` if no path found |
+
+# Function's Code
+
+The function `find_image_path` performs text parsing and pattern recognition to find image paths within the provided text. It uses the regular expressions (`re`) module to detect all potential paths.
+
+```python
+import os
+import re
+
+
+def find_image_path(text):
+ pattern = r"([A-Za-z]:\\[^:\n]*?\.(png|jpg|jpeg|PNG|JPG|JPEG))|(/[^:\n]*?\.(png|jpg|jpeg|PNG|JPG|JPEG))"
+ matches = [
+ match.group()
+ for match in re.finditer(pattern, text)
+ if match.group()
+ ]
+ matches += [match.replace("\\", "") for match in matches if match]
+ existing_paths = [
+ match for match in matches if os.path.exists(match)
+ ]
+ return max(existing_paths, key=len) if existing_paths else None
+```
+
+# Usage Examples
+
+Let's consider examples of how the function `find_image_path` can be used in different scenarios.
+
+**Example 1:**
+
+Consider the case where a text without any image path is provided.
+
+```python
+from swarms.utils import find_image_path
+
+text = "There are no image paths in this text"
+print(find_image_path(text)) # Outputs: None
+```
+
+**Example 2:**
+
+Consider the case where the text has multiple image paths.
+
+```python
+from swarms.utils import find_image_path
+
+text = "Here is an image path: /home/user/image1.png. Here is another one: C:\\Users\\User\\Documents\\image2.jpeg"
+print(find_image_path(text)) # Outputs: the longest image path (depends on your file system and existing files)
+```
+
+**Example 3:**
+
+In the final example, we consider a case where the text has an image path, but the file does not exist.
+
+```python
+from swarms.utils import find_image_path
+
+text = "Here is an image path: /home/user/non_existant.png"
+print(find_image_path(text)) # Outputs: None
+```
+
+# Closing Notes
+
+In conclusion, the `find_image_path` function is crucial in the `swarms.utils` module as it supports a key operation of identifying image paths within given input text. This allows users to automate the extraction of such data from larger documents/text. However, it's important to note the function returns only existing paths in your file system and only the longest if multiple exist.
diff --git a/docs/swarms/utils/limit_tokens_from_string.md b/docs/swarms/utils/limit_tokens_from_string.md
new file mode 100644
index 00000000..b096ebad
--- /dev/null
+++ b/docs/swarms/utils/limit_tokens_from_string.md
@@ -0,0 +1,82 @@
+# limit_tokens_from_string
+
+## Introduction
+The `swarms.utils` library contains utility functions used across the codebase for machine learning and other operations. One notable function is `limit_tokens_from_string()`, which limits the number of tokens in a given string.
+
+# Function: limit_tokens_from_string()
+Within the `swarms.utils` library, there is a function with the signature `limit_tokens_from_string(string: str, model: str = "gpt-4", limit: int = 500) -> str`.
+
+## Description
+The function `limit_tokens_from_string()` limits the number of tokens in a given string based on the specified threshold. It is primarily useful when you are handling large text data and need to chunk or truncate your text to a certain length, for example when computational resources are limited or when a model accepts only a specific maximum amount of text.
+
+## Parameters
+
+| Parameter | Type | Default Value | Description |
+| :-----------| :----------- | :------------ | :------------|
+| `string` | `str` | `None` | The input string from which the tokens need to be limited. |
+| `model` | `str` | `"gpt-4"` | The model used to encode and decode the tokens. The function defaults to `gpt-4`, but you can specify any model supported by `tiktoken`. If a model is not found, it falls back to the `gpt2` encoding. |
+| `limit` | `int` | `500` | The maximum number of tokens to keep. Default is 500. |
+
+## Returns
+
+| Return | Type | Description |
+| :-----------| :----------- | :------------ |
+| `out` | `str` | A string reconstructed from the encoded tokens, limited to the first `limit` tokens. |
+
+## Method Detail and Usage Examples
+
+The method `limit_tokens_from_string()` takes in three parameters - `string`, `model`, and `limit`.
+
+
+First, it tries to get the encoding for the model specified in the `model` argument using `tiktoken.encoding_for_model(model)`. If the specified model is not found, the function falls back to the `gpt2` encoding.
+
+Next, the input `string` is tokenized using the `encode` method of the `encoding` object. This produces the list of token ids, `encoded`.
+
+Then, the function slices `encoded` to keep only the first `limit` tokens.
+
+Finally, the function converts the tokens back into a string using the `decode` method of the `encoding` object. The resulting string `out` is returned.
+
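+The behavior described above can be sketched with `tiktoken` directly, as shown below. This is an illustrative approximation, not necessarily the exact `swarms.utils` source.
+
+```python
+import tiktoken
+
+def limit_tokens_sketch(string: str, model: str = "gpt-4", limit: int = 500) -> str:
+    """Illustrative sketch of the token-limiting behavior described above."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Unknown model: fall back to the gpt2 encoding
+        encoding = tiktoken.get_encoding("gpt2")
+
+    encoded = encoding.encode(string)       # tokenize the input
+    out = encoding.decode(encoded[:limit])  # keep only the first `limit` tokens
+    return out
+```
+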
+### Example 1:
+
+```python
+from swarms.utils import limit_tokens_from_string
+
+# longer input string
+string = "This is a very long string that needs to be tokenized. This string might exceed the maximum token limit, so it will need to be truncated."
+
+# lower token limit
+limit = 10
+
+output = limit_tokens_from_string(string, limit=limit)
+```
+
+### Example 2:
+
+```python
+from swarms.utils import limit_tokens_from_string
+
+# longer input string with different model
+string = "This string will be tokenized using gpt2 model. If the string is too long, it will be truncated."
+
+# model
+model = "gpt2"
+
+output = limit_tokens_from_string(string, model=model)
+```
+
+### Example 3:
+
+```python
+from swarms.utils import limit_tokens_from_string
+
+# try with a model name that tiktoken does not recognize
+string = "In case the method does not find the specified model, it will fall back to the gpt2 model."
+
+# model (a hypothetical, unsupported name used here to trigger the gpt2 fallback)
+model = "some-unsupported-model"
+
+output = limit_tokens_from_string(string, model=model)
+```
+
+**Note:** If you specify a model that `tiktoken` does not support, the function falls back to the `gpt2` encoding.
+
diff --git a/docs/swarms/utils/load_model_torch.md b/docs/swarms/utils/load_model_torch.md
new file mode 100644
index 00000000..ddcd7ee6
--- /dev/null
+++ b/docs/swarms/utils/load_model_torch.md
@@ -0,0 +1,102 @@
+# load_model_torch
+
+# load_model_torch: Utility Function Documentation
+
+## Introduction:
+
+`load_model_torch` is a utility function in the `swarms.utils` library designed to load a saved PyTorch model and move it to the designated device. It provides flexibility, allowing the user to specify the model file location, the device to which the loaded model should be moved, whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()`, and more.
+
+Moreover, if the saved model file only contains the state dictionary, but not the model architecture, you can pass the model architecture as an argument.
+
+## Function Definition and Parameters:
+
+```python
+def load_model_torch(
+ model_path: str = None,
+ device: torch.device = None,
+ model: nn.Module = None,
+ strict: bool = True,
+ map_location=None,
+ *args,
+ **kwargs,
+) -> nn.Module:
+```
+
+The following table describes the parameters in detail:
+
+| Name | Type | Default Value | Description |
+| ------ | ------ | ------------- | ------------|
+| model_path | str | None | A string specifying the path to the saved model file on disk. _Required_ |
+| device | torch.device | None | A `torch.device` object that specifies the target device for the loaded model. If not provided, the function checks for the availability of a GPU and uses it if available. If not, it defaults to CPU. |
+| model | nn.Module | None | An instance of `torch.nn.Module` representing the model's architecture. This parameter is required if the model file only contains the model's state dictionary and not the model architecture. |
+| strict | bool | True | A boolean that determines whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function. If set to `True`, the function will raise a KeyError when the state dictionary and `state_dict()` keys do not match. |
+| map_location | callable | None | A function to remap the storage locations of the loaded model's parameters. Useful for loading models saved on a device type that is different from the current one. |
+| *args, **kwargs | - | - | Additional arguments and keyword arguments to be passed to `torch.load`. |
+
+Returns:
+
+- `torch.nn.Module` - The loaded model after moving it to the desired device.
+
+Raises:
+
+- `FileNotFoundError` - If the saved model file is not found at the specified path.
+- `RuntimeError` - If there was an error while loading the model.
+
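+For intuition, the loading behavior described above can be sketched roughly as follows. This is an approximation based on the documented parameters, not the exact `swarms.utils` source.
+
+```python
+import torch
+import torch.nn as nn
+
+def load_model_torch_sketch(
+    model_path: str,
+    device: torch.device = None,
+    model: nn.Module = None,
+    strict: bool = True,
+    map_location=None,
+) -> nn.Module:
+    """Illustrative sketch of the loading behavior described above."""
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    checkpoint = torch.load(model_path, map_location=map_location)
+    if model is not None:
+        # The file holds only a state dict: load it into the provided architecture
+        model.load_state_dict(checkpoint, strict=strict)
+    else:
+        # The file holds a fully serialized model
+        model = checkpoint
+
+    return model.to(device)
+```
+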
+## Example of Usage:
+
+This function can be used directly inside your code as shown in the following examples:
+
+### Example 1:
+Loading a model without specifying a device results in the function automatically choosing the optimal available device.
+
+```python
+from swarms.utils import load_model_torch
+import torch.nn as nn
+
+# Assume `mymodel.pth` is in the current directory
+model_path = "./mymodel.pth"
+
+# Define your model architecture if the model file only contains state dict
+class MyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(10, 2)
+
+ def forward(self, x):
+ return self.linear(x)
+
+model = MyModel()
+
+# Load the model
+loaded_model = load_model_torch(model_path, model=model)
+
+# Now you can use the loaded model for prediction or further training
+```
+### Example 2:
+Explicitly specifying a device.
+
+```python
+import torch
+from swarms.utils import load_model_torch
+
+# Assume `mymodel.pth` is in the current directory
+model_path = "./mymodel.pth"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the model
+loaded_model = load_model_torch(model_path, device=device)
+```
+
+### Example 3:
+Using a model file that contains only the state dictionary, not the model architecture.
+
+```python
+import torch
+
+# Assume `mymodel_state_dict.pth` is in the current directory
+model_path = "./mymodel_state_dict.pth"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Define your model architecture
+model = MyModel()
+
+# Load the model
+loaded_model = load_model_torch(model_path, device=device, model=model)
+```
+
+This gives you an insight into how to use the `load_model_torch` utility function from the `swarms.utils` library efficiently. Always remember to pass the model path argument, while the other arguments are optional depending on your requirements. Furthermore, handle exceptions properly so your PyTorch-related projects run smoothly.
diff --git a/docs/swarms/utils/math_eval.md b/docs/swarms/utils/math_eval.md
index 21fd62cc..691089f8 100644
--- a/docs/swarms/utils/math_eval.md
+++ b/docs/swarms/utils/math_eval.md
@@ -1,99 +1,78 @@
-# Math Evaluation Decorator Documentation
+# math_eval
-## Introduction
-The Math Evaluation Decorator is a utility function that helps you compare the output of two functions, `func1` and `func2`, when given the same input. This decorator is particularly useful for validating whether a generated function produces the same results as a ground truth function. This documentation provides a detailed explanation of the Math Evaluation Decorator, its purpose, usage, and examples.
-## Purpose
-The Math Evaluation Decorator serves the following purposes:
-1. To compare the output of two functions, `func1` and `func2`, when given the same input.
-2. To log any errors that may occur during the evaluation.
-3. To provide a warning if the outputs of `func1` and `func2` do not match.
+The `math_eval` function is a Python decorator that runs two functions on the same inputs and compares their results. It can be used to test functions that are expected to behave identically, or in situations where two different methods compute or retrieve the same value and the results need to be cross-checked.
-## Decorator Definition
-```python
-def math_eval(func1, func2):
- """Math evaluation decorator.
-
- Args:
- func1 (_type_): The first function to be evaluated.
- func2 (_type_): The second function to be evaluated.
-
- Example:
- >>> @math_eval(ground_truth, generated_func)
- >>> def test_func(x):
- >>> return x
- >>> result1, result2 = test_func(5)
- >>> print(f"Result from ground_truth: {result1}")
- >>> print(f"Result from generated_func: {result2}")
-
- """
-```
-
-### Parameters
-| Parameter | Type | Description |
-|-----------|--------|--------------------------------------------------|
-| `func1` | _type_ | The first function to be evaluated. |
-| `func2` | _type_ | The second function to be evaluated. |
+The `math_eval` function in this case accepts two functions as parameters: `func1` and `func2`, and returns a decorator. This returned decorator, when applied to a function, enhances that function to execute both `func1` and `func2`, and compare the results.
-## Usage
-The Math Evaluation Decorator is used as a decorator for a test function that you want to evaluate. Here's how to use it:
+This is particularly useful when you are implementing a new function and want to compare its behavior and results with those of an existing one under the same input parameters. It also logs a warning when the results do not match, which can be quite useful while debugging.
-1. Define the two functions, `func1` and `func2`, that you want to compare.
+## Usage Example
-2. Create a test function and decorate it with `@math_eval(func1, func2)`.
+Let's say you have two functions: `ground_truth` and `generated_func`, that have similar functionalities or serve the same purpose. You are writing a new function called `test_func`, and you'd like to compare the results of `ground_truth` and `generated_func` when `test_func` is run. Here is how you would use the `math_eval` decorator:
-3. In the test function, provide the input(s) to both `func1` and `func2`.
-
-4. The decorator will compare the outputs of `func1` and `func2` when given the same input(s).
-
-5. Any errors that occur during the evaluation will be logged.
-
-6. If the outputs of `func1` and `func2` do not match, a warning will be generated.
-
-## Examples
-
-### Example 1: Comparing Two Simple Functions
```python
-# Define the ground truth function
-def ground_truth(x):
- return x * 2
-
-# Define the generated function
-def generated_func(x):
- return x - 10
-
-# Create a test function and decorate it
@math_eval(ground_truth, generated_func)
def test_func(x):
return x
-
-# Evaluate the test function with an input
result1, result2 = test_func(5)
-
-# Print the results
print(f"Result from ground_truth: {result1}")
print(f"Result from generated_func: {result2}")
```
-In this example, the decorator compares the outputs of `ground_truth` and `generated_func` when given the input `5`. If the outputs do not match, a warning will be generated.
+## Parameters
+
+| Parameter | Data Type | Description |
+| ---- | ---- | ---- |
+| func1 | Callable | The first function whose result you want to compare. |
+| func2 | Callable | The second function whose result you want to compare. |
+
+No stricter type can be given for `func1` and `func2`, since they may be any Python function (or callable object). Any exception raised by either function is caught and logged inside the decorator, and the corresponding result is set to `None`.
+
+## Return Values
-### Example 2: Handling Errors
-If an error occurs in either `func1` or `func2`, the decorator will log the error and set the result to `None`. This ensures that the evaluation continues even if one of the functions encounters an issue.
+The `math_eval` function itself returns a decorator rather than a direct value. When applied to a function, the decorator alters the wrapped function to return two values:
-## Additional Information and Tips
+1. `result1`: The result of running `func1` with the given input parameters.
+2. `result2`: The result of running `func2` with the given input parameters.
-- The Math Evaluation Decorator is a powerful tool for comparing the outputs of functions, especially when validating machine learning models or generated code.
+These two return values are provided in that order as a tuple.
-- Ensure that the functions `func1` and `func2` take the same input(s) to ensure a meaningful comparison.
+## Source Code
-- Regularly check the logs for any errors or warnings generated during the evaluation.
+Here's how to implement the `math_eval` decorator:
-- If the decorator logs a warning about mismatched outputs, investigate and debug the functions accordingly.
+```python
+import functools
+import logging
+
+def math_eval(func1, func2):
+ """Math evaluation decorator."""
-## References and Resources
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ result1 = func1(*args, **kwargs)
+ except Exception as e:
+ logging.error(f"Error in func1: {e}")
+ result1 = None
-- For more information on Python decorators, refer to the [Python Decorators Documentation](https://docs.python.org/3/glossary.html#term-decorator).
+ try:
+ result2 = func2(*args, **kwargs)
+ except Exception as e:
+ logging.error(f"Error in func2: {e}")
+ result2 = None
-- Explore advanced use cases of the Math Evaluation Decorator in your projects to ensure code correctness and reliability.
+ if result1 != result2:
+ logging.warning(
+ f"Outputs do not match: {result1} != {result2}"
+ )
-This comprehensive documentation explains the Math Evaluation Decorator, its purpose, usage, and examples. Use this decorator to compare the outputs of functions and validate code effectively.
\ No newline at end of file
+ return result1, result2
+
+ return wrapper
+
+ return decorator
+```
+Note that the decorator only logs exceptions and mismatches to aid debugging; how your application responds to them is up to you, so customize the error handling to match your requirements.
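+
+As a self-contained sketch (assuming `math_eval` is importable from `swarms.utils`), here is how the decorator can cross-check two implementations of the same computation; the function names are purely illustrative:
+
+```python
+from swarms.utils import math_eval
+
+
+def ground_truth(x):
+    return x * 2
+
+
+def generated_func(x):
+    return x + x  # an equivalent implementation
+
+
+@math_eval(ground_truth, generated_func)
+def double(x):
+    return x
+
+
+result1, result2 = double(21)
+print(result1, result2)  # 42 42, so no mismatch warning is logged
+```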
diff --git a/docs/swarms/utils/metrics_decorator.md b/docs/swarms/utils/metrics_decorator.md
new file mode 100644
index 00000000..aeafe151
--- /dev/null
+++ b/docs/swarms/utils/metrics_decorator.md
@@ -0,0 +1,86 @@
+# metrics_decorator
+
+This documentation explains the use and functionality of the `metrics_decorator` function, a utility for measuring the performance of LLM (Large Language Model) calls.
+
+The `metrics_decorator` function is a standard Python decorator: it augments a function by wrapping extra functionality around it, in the same way decorators are commonly used for timing, logging, or memoization.
+
+Here, `metrics_decorator` is specifically designed to measure and report three key performance metrics when generating text with a language model:
+
+1. `Time to First Token`: Measures the elapsed time from the start of function execution until the generation of the first token.
+2. `Generation Latency`: It measures the total time taken for a complete run.
+3. `Throughput`: Calculates the rate of production of tokens per unit of time.
+
+```python
+import time
+from functools import wraps
+from typing import Callable
+
+
+def metrics_decorator(func: Callable):
+ """
+
+ Metrics decorator for LLM
+
+ Args:
+ func (Callable): The function to be decorated.
+
+ """
+
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ """
+ An inner function that wraps the decorated function. It calculates 'Time to First Token',
+ 'Generation Latency' and 'Throughput' metrics.
+
+ Args:
+ self : The object instance.
+ *args : Variable length argument list of the decorated function.
+ **kwargs : Arbitrary keyword arguments of the decorated function.
+ """
+
+ # Measure Time to First Token
+ start_time = time.time()
+ result = func(self, *args, **kwargs)
+ first_token_time = time.time()
+
+ # Measure Generation Latency
+ end_time = time.time()
+
+ # Calculate Throughput (assuming the function returns a list of tokens)
+ throughput = len(result) / (end_time - start_time)
+
+ return f"""
+ Time to First Token: {first_token_time - start_time}
+ Generation Latency: {end_time - start_time}
+ Throughput: {throughput}
+ """
+
+ return wrapper
+```
+## Example Usage
+Now let's discuss the usage of the `metrics_decorator` function with an example.
+
+Assuming that we have a language generation function called `text_generator()` that generates a list of tokens.
+
+```python
+@metrics_decorator
+def text_generator(self, text: str):
+ """
+ Args:
+ text (str): The input text.
+
+ Returns:
+ A list of tokens generated from the input text.
+ """
+ # language generation implementation goes here
+ return tokens
+
+# Instantiate the class and call the decorated function
+obj = ClassName()
+obj.text_generator("Hello, world!")
+```
+
+When the decorated `text_generator()` function is called, it will measure and return:
+
+- Time elapsed until the first token is generated.
+- The total execution time of the function.
+- The rate of tokens generation per unit time.
+
+This example provides a basic overview of how a function can be decorated with `metrics_decorator`. The decorated callable can be any instance method (the wrapper expects `self` as its first argument), and it must return a sequence of tokens for the `Throughput` metric to be meaningful.
+
+Keep in mind that the wrapper returns a formatted string containing the three metrics rather than the wrapped function's original return value, so use it where the metrics report is what you want to consume. It is a handy utility for tracking and optimizing the performance of your language models.
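+
+For a self-contained sketch, the toy class below stands in for a real LLM wrapper; only `metrics_decorator` is assumed to come from `swarms.utils`, everything else is illustrative:
+
+```python
+import time
+
+from swarms.utils import metrics_decorator
+
+
+class ToyLLM:
+    """Hypothetical stand-in for a real LLM wrapper."""
+
+    @metrics_decorator
+    def generate(self, text: str):
+        time.sleep(0.01)  # simulate generation latency
+        # Pretend every word of the prompt is one generated token
+        return text.split()
+
+
+llm = ToyLLM()
+print(llm.generate("Hello world from a toy model"))
+```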
diff --git a/docs/swarms/utils/pdf_to_text.md b/docs/swarms/utils/pdf_to_text.md
new file mode 100644
index 00000000..aecde1a9
--- /dev/null
+++ b/docs/swarms/utils/pdf_to_text.md
@@ -0,0 +1,71 @@
+# pdf_to_text
+
+## Introduction
+The function `pdf_to_text` is a Python utility for converting a PDF file into a string of text content. It leverages the `PyPDF2` library, an excellent Python library for processing PDF files. The function takes in a PDF file's path and reads its content, subsequently returning the extracted textual data.
+
+This function is useful whenever you want to extract textual information from PDF files automatically, for instance when processing a large number of documents, performing textual analysis, or dealing with text data that is only available in PDF format.
+
+## Class / Function Definition
+
+`pdf_to_text` is a standalone function defined as follows:
+
+```python
+def pdf_to_text(pdf_path: str) -> str:
+```
+
+## Parameters
+
+| Parameter | Type | Description |
+|:-:|---|---|
+| pdf_path | str | The path to the PDF file to be converted |
+
+## Returns
+
+| Return Value | Type | Description |
+|:-:|---|---|
+| text | str | The text extracted from the PDF file. |
+
+## Raises
+
+| Exception | Description |
+|---|---|
+| FileNotFoundError | If the PDF file is not found at the specified path. |
+| Exception | If there is an error in reading the PDF file. |
+
+## Function Description
+
+`pdf_to_text` uses the `PdfReader` class from the `PyPDF2` library to read the PDF file. If the PDF file does not exist at the specified path, or an error occurs while reading it, the corresponding exception is raised. The function then iterates through each page of the PDF, calls `extract_text` on each page, concatenates the page contents into a single string, and returns the result.
+
+## Usage Examples
+
+To use this function, you first need to install the `PyPDF2` library. It can be installed via pip:
+
+```bash
+pip install pypdf2
+```
+
+Then, you should import the `pdf_to_text` function:
+
+```python
+from swarms.utils import pdf_to_text
+```
+
+Here is an example of how to use `pdf_to_text`:
+
+```python
+# Define the path to the pdf file
+pdf_path = 'sample.pdf'
+
+# Use the function to extract text
+text = pdf_to_text(pdf_path)
+
+# Print the extracted text
+print(text)
+```
+
+## Tips and Additional Information
+- Ensure that the PDF file path is valid and that the file exists at the specified location. If the file does not exist, a `FileNotFoundError` will be raised.
+- This function reads the text from the PDF. It does not handle images, graphical elements, or any non-text content.
+- If the PDF contains scanned images rather than textual data, the `extract_text` function may not be able to extract any text. In such cases, you would require OCR (Optical Character Recognition) tools to extract the text.
+- Be aware of the possibility that the output string might contain special characters or escape sequences because they were part of the PDF's content. You might need to clean the resulting text according to your requirements.
+- The function uses the PyPDF2 library to facilitate the PDF reading and text extraction. For any issues related to PDF manipulation, consult the [PyPDF2 library documentation](https://pythonhosted.org/PyPDF2/).
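+
+As a minimal sketch of that clean-up step (with `sample.pdf` as a placeholder file name), you could normalize the whitespace in the extracted text like this:
+
+```python
+import re
+
+from swarms.utils import pdf_to_text
+
+text = pdf_to_text("sample.pdf")
+
+# Collapse runs of whitespace left over from the PDF layout
+cleaned = re.sub(r"\s+", " ", text).strip()
+print(cleaned[:200])
+```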
diff --git a/docs/swarms/utils/prep_torch_inference.md b/docs/swarms/utils/prep_torch_inference.md
new file mode 100644
index 00000000..68598fa8
--- /dev/null
+++ b/docs/swarms/utils/prep_torch_inference.md
@@ -0,0 +1,102 @@
+# prep_torch_inference
+
+```python
+def prep_torch_inference(
+ model_path: str = None,
+ device: torch.device = None,
+ *args,
+ **kwargs,
+):
+ """
+ Prepare a Torch model for inference.
+
+ Args:
+ model_path (str): Path to the model file.
+ device (torch.device): Device to run the model on.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ torch.nn.Module: The prepared model.
+ """
+ try:
+ model = load_model_torch(model_path, device)
+ model.eval()
+ return model
+ except Exception as e:
+ # Add error handling code here
+ print(f"Error occurred while preparing Torch model: {e}")
+ return None
+```
+This method is part of the 'swarms.utils' module. It accepts a model file path and a torch device as input and returns a model that is ready for inference.
+
+## Detailed Functionality
+
+The method loads a PyTorch model from the file specified by `model_path`. This model is then moved to the specified `device` if it is provided. Subsequently, the method sets the model to evaluation mode by calling `model.eval()`. This is a crucial step when preparing a model for inference, as certain layers like dropout or batch normalization behave differently during training vs during evaluation.
+In the case of any exception (e.g., the model file not found or the device unavailable), it prints an error message and returns `None`.
+
+## Parameters
+
+| Parameter | Type | Description | Default |
+|-----------|------|-------------|---------|
+| model_path | str | Path to the model file. | None |
+| device | torch.device | Device to run the model on. | None |
+| *args | tuple | Additional positional arguments. | - |
+| **kwargs | dict | Additional keyword arguments. | - |
+
+## Returns
+
+| Type | Description |
+|------|-------------|
+| torch.nn.Module | The prepared model ready for inference. Returns `None` if any exception occurs. |
+
+## Usage Examples
+
+Here are some examples of how you can use the `prep_torch_inference` method. Before that, you need to import the necessary modules as follows:
+
+```python
+import torch
+from swarms.utils import prep_torch_inference, load_model_torch
+```
+
+### Example 1: Load a model for inference on CPU
+
+```python
+model_path = "saved_model.pth"
+model = prep_torch_inference(model_path)
+
+if model is not None:
+ print("Model loaded successfully and is ready for inference.")
+else:
+ print("Failed to load the model.")
+```
+
+### Example 2: Load a model for inference on CUDA device
+
+```python
+model_path = "saved_model.pth"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = prep_torch_inference(model_path, device)
+
+if model is not None:
+ print(f"Model loaded successfully on device {device} and is ready for inference.")
+else:
+ print("Failed to load the model.")
+```
+
+### Example 3: Load a model with additional arguments for `load_model_torch`
+
+```python
+model_path = "saved_model.pth"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Suppose load_model_torch accepts an additional argument, map_location
+model = prep_torch_inference(model_path, device, map_location=device)
+
+if model is not None:
+ print(f"Model loaded successfully on device {device} and is ready for inference.")
+else:
+ print("Failed to load the model.")
+```
+
+Please note that the given model path must exist and the device must be available on your machine, otherwise `prep_torch_inference` returns `None`. Depending on the size and complexity of your models, loading them onto a specific device can take a while, so factor that into the design of your machine learning workflows.
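+
+As a minimal sketch of the next step, here is how you might run a forward pass with the prepared model; the checkpoint path and input shape are hypothetical and must match your own model:
+
+```python
+import torch
+
+from swarms.utils import prep_torch_inference
+
+model = prep_torch_inference("saved_model.pth")
+
+if model is not None:
+    # Hypothetical input: a batch of one sample with 10 features
+    sample = torch.randn(1, 10)
+    with torch.no_grad():
+        output = model(sample)
+    print(output)
+```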
diff --git a/docs/swarms/utils/print_class_parameters.md b/docs/swarms/utils/print_class_parameters.md
new file mode 100644
index 00000000..3c09578f
--- /dev/null
+++ b/docs/swarms/utils/print_class_parameters.md
@@ -0,0 +1,110 @@
+# print_class_parameters
+
+The `print_class_parameters` function is a utility that helps developers and users retrieve and print the parameters of a class constructor in Python, either writing them to standard output or returning them as a dictionary when `api_format` is set to `True`.
+
+This utility function utilizes the `inspect` module to fetch the signature of the class constructor and fetches the parameters from the obtained signature. The parameter values and their respective types are then outputted.
+
+This function allows developers to easily inspect and understand a class's constructor parameters without having to read through the class definition by hand, which eases testing and debugging and helps keep code readable.
+
+__Function Definition:__
+
+```python
+def print_class_parameters(cls, api_format: bool = False):
+```
+__Parameters:__
+
+| Parameter | Type | Description | Default value |
+|---|---|---|---|
+| cls | type | The Python class to inspect. | None |
+| api_format | bool | Flag to determine if the output should be returned in dictionary format (if set to True) or printed out (if set to False) | False |
+
+__Functionality and Usage:__
+
+Inside the `print_class_parameters` function, it starts by getting the signature of the constructor of the inputted class by invoking `inspect.signature(cls.__init__)`. It then extracts the parameters from the signature and stores it in the `params` variable.
+
+If the `api_format` argument is set to `True`, instead of printing the parameters and their types, it stores them inside a dictionary where each key-value pair is a parameter name and its type. It then returns this dictionary.
+
+If `api_format` is set to `False` or not set at all (defaulting to False), the function iterates over the parameters and prints the parameter name and its type. "self" parameters are excluded from the output as they are inherent to all class methods in Python.
+
+The most likely failure point is the call to `inspect.signature()`. If the given class has no usable `__init__` method, or any other error occurs while retrieving the constructor's signature, the exception is caught and an error message containing the details is printed.
+
+__Usage and Examples:__
+
+Assuming the existence of a class:
+
+```python
+class Agent:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+```
+
+One could use `print_class_parameters` in its typical usage:
+
+```python
+print_class_parameters(Agent)
+```
+
+Results in:
+
+```
+Parameter: x, Type: <class 'int'>
+Parameter: y, Type: <class 'int'>
+```
+
+Or, with `api_format` set to `True`
+
+```python
+output = print_class_parameters(Agent, api_format=True)
+print(output)
+```
+
+Results in:
+
+```
+{'x': "<class 'int'>", 'y': "<class 'int'>"}
+```
+
+__Note:__
+
+The function `print_class_parameters` is not limited to custom classes. It can inspect built-in Python classes such as `list`, `dict`, and others. However, it is most useful when inspecting custom-defined classes that aren't inherently documented in Python or third-party libraries.
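+
+Continuing with the `Agent` class from the examples above, here is a small sketch of how the `api_format=True` output could be consumed, for instance to render a quick parameter table; the formatting choices are illustrative only:
+
+```python
+params = print_class_parameters(Agent, api_format=True)
+
+# Render a simple markdown-style table from the returned dictionary
+print("| Parameter | Type |")
+print("|---|---|")
+for name, annotation in params.items():
+    print(f"| {name} | {annotation} |")
+```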
+
+__Source Code__
+
+```python
+def print_class_parameters(cls, api_format: bool = False):
+ """
+ Print the parameters of a class constructor.
+
+ Parameters:
+ cls (type): The class to inspect.
+
+ Example:
+ >>> print_class_parameters(Agent)
+    Parameter: x, Type: <class 'int'>
+    Parameter: y, Type: <class 'int'>
+ """
+ try:
+ # Get the parameters of the class constructor
+ sig = inspect.signature(cls.__init__)
+ params = sig.parameters
+
+ if api_format:
+ param_dict = {}
+ for name, param in params.items():
+ if name == "self":
+ continue
+ param_dict[name] = str(param.annotation)
+ return param_dict
+
+ # Print the parameters
+ for name, param in params.items():
+ if name == "self":
+ continue
+ print(f"Parameter: {name}, Type: {param.annotation}")
+
+ except Exception as e:
+ print(f"An error occurred while inspecting the class: {e}")
+```
diff --git a/mkdocs.yml b/mkdocs.yml
index fa2f6955..f02daa9a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -61,10 +61,6 @@ nav:
- Docker Container Setup: "docker_setup.md"
- Swarms:
- Overview: "swarms/index.md"
- - swarms.swarms:
- - AbstractSwarm: "swarms/swarms/abstractswarm.md"
- - GodMode: "swarms/swarms/godmode.md"
- - Groupchat: "swarms/swarms/groupchat.md"
- swarms.workers:
- Overview: "swarms/workers/index.md"
- AbstractWorker: "swarms/workers/abstract_worker.md"
@@ -99,18 +95,33 @@ nav:
- ElevenLabsText2SpeechTool: "swarms/models/elevenlabs.md"
- OpenAITTS: "swarms/models/openai_tts.md"
- Gemini: "swarms/models/gemini.md"
+ - ZeroscopeTTV: "swarms/models/zeroscope.md"
- swarms.structs:
- Overview: "swarms/structs/overview.md"
- AutoScaler: "swarms/swarms/autoscaler.md"
- Agent: "swarms/structs/agent.md"
- SequentialWorkflow: 'swarms/structs/sequential_workflow.md'
- Conversation: "swarms/structs/conversation.md"
+ - AbstractSwarm: "swarms/swarms/abstractswarm.md"
+ - ModelParallelizer: "swarms/swarms/ModelParallelizer.md"
+ - Groupchat: "swarms/swarms/groupchat.md"
- swarms.memory:
- Weaviate: "swarms/memory/weaviate.md"
- PineconeDB: "swarms/memory/pinecone.md"
- PGVectorStore: "swarms/memory/pg.md"
- ShortTermMemory: "swarms/memory/short_term_memory.md"
- swarms.utils:
+ - pdf_to_text: "swarms/utils/pdf_to_text.md"
+ - load_model_torch: "swarms/utils/load_model_torch.md"
+ - metrics_decorator: "swarms/utils/metrics_decorator.md"
+ - prep_torch_inference: "swarms/utils/prep_torch_inference.md"
+ - find_image_path: "swarms/utils/find_image_path.md"
+ - print_class_parameters: "swarms/utils/print_class_parameters.md"
+ - extract_code_from_markdown: "swarms/utils/extract_code_from_markdown.md"
+ - check_device: "swarms/utils/check_device.md"
+ - display_markdown_message: "swarms/utils/display_markdown_message.md"
+ - phoenix_tracer: "swarms/utils/phoenix_tracer.md"
+ - limit_tokens_from_string: "swarms/utils/limit_tokens_from_string.md"
- math_eval: "swarms/utils/math_eval.md"
- Guides:
- Overview: "examples/index.md"
diff --git a/playground/agents/simple_agent.py b/playground/agents/simple_agent.py
index b30d93c8..dd46083b 100644
--- a/playground/agents/simple_agent.py
+++ b/playground/agents/simple_agent.py
@@ -1,26 +1,45 @@
-from swarms.agents.simple_agent import SimpleAgent
-from swarms.structs import Agent
-from swarms.models import OpenAIChat
+import os
-api_key = ""
+from dotenv import load_dotenv
-llm = OpenAIChat(
- openai_api_key=api_key,
- temperature=0.5,
+from swarms import (
+ OpenAIChat,
+ Conversation,
)
-# Initialize the agent
-agent = Agent(
- llm=llm,
- max_loops=5,
+conv = Conversation(
+ time_enabled=True,
)
+# Load the environment variables
+load_dotenv()
+
+# Get the API key from the environment
+api_key = os.environ.get("OPENAI_API_KEY")
+
+# Initialize the language model
+llm = OpenAIChat(openai_api_key=api_key, model_name="gpt-4")
+
+
+# Run the language model in a loop
+def interactive_conversation(llm, iters: int = 10):
+ conv = Conversation()
+ for i in range(iters):
+ user_input = input("User: ")
+ conv.add("user", user_input)
+ if user_input.lower() == "quit":
+ break
+ task = (
+ conv.return_history_as_string()
+ ) # Get the conversation history
+ out = llm(task)
+ conv.add("assistant", out)
+ print(
+ f"Assistant: {out}",
+ )
+ conv.display_conversation()
+ conv.export_conversation("conversation.txt")
-agent = SimpleAgent(
- name="Optimus Prime",
- agent=agent,
- # Memory
-)
-out = agent.run("Generate a 10,000 word blog on health and wellness.")
-print(out)
+# Replace with your LLM instance
+interactive_conversation(llm)
diff --git a/playground/demos/llm_with_conversation/main.py b/playground/demos/llm_with_conversation/main.py
index 2bb28b4b..a9e6c42a 100644
--- a/playground/demos/llm_with_conversation/main.py
+++ b/playground/demos/llm_with_conversation/main.py
@@ -4,7 +4,6 @@ from dotenv import load_dotenv
# Import the OpenAIChat model and the Agent struct
from swarms.models import OpenAIChat
-from swarms.structs import Agent
# Load the environment variables
load_dotenv()
diff --git a/playground/swarms/godmode.py b/playground/swarms/godmode.py
index f1269d98..46f71393 100644
--- a/playground/swarms/godmode.py
+++ b/playground/swarms/godmode.py
@@ -1,16 +1,33 @@
-from swarms.swarms import GodMode
-from swarms.models import OpenAIChat
+import os
-api_key = ""
+from dotenv import load_dotenv
-llm = OpenAIChat(openai_api_key=api_key)
+from swarms.models import Anthropic, Gemini, Mixtral, OpenAIChat
+from swarms.swarms import ModelParallelizer
+load_dotenv()
-llms = [llm, llm, llm]
+# API Keys
+anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+gemini_api_key = os.getenv("GEMINI_API_KEY")
-god_mode = GodMode(llms)
+# Initialize the models
+llm = OpenAIChat(openai_api_key=openai_api_key)
+anthropic = Anthropic(anthropic_api_key=anthropic_api_key)
+mixtral = Mixtral()
+gemini = Gemini(gemini_api_key=gemini_api_key)
+# Initialize the parallelizer
+llms = [llm, anthropic, mixtral, gemini]
+parallelizer = ModelParallelizer(llms)
+
+# Set the task
task = "Generate a 10,000 word blog on health and wellness."
-out = god_mode.run(task)
-god_mode.print_responses(task)
+# Run the task
+out = parallelizer.run(task)
+
+# Print the responses 1 by 1
+for i in range(len(out)):
+ print(f"Response from LLM {i}: {out[i]}")
diff --git a/playground/swarms/groupchat.py b/playground/swarms/groupchat.py
index f53257c7..b9ab5761 100644
--- a/playground/swarms/groupchat.py
+++ b/playground/swarms/groupchat.py
@@ -1,5 +1,5 @@
from swarms import OpenAI, Agent
-from swarms.swarms.groupchat import GroupChatManager, GroupChat
+from swarms.structs.groupchat import GroupChatManager, GroupChat
api_key = ""
diff --git a/pyproject.toml b/pyproject.toml
index d29c59ad..f7379c21 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
-version = "2.4.2"
+version = "3.1.3"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez "]
@@ -39,7 +39,7 @@ backoff = "2.2.1"
marshmallow = "3.19.0"
datasets = "2.10.1"
optimum = "1.15.0"
-diffusers = "0.17.1"
+diffusers = "*"
PyPDF2 = "3.0.1"
accelerate = "0.22.0"
sentencepiece = "0.1.98"
@@ -70,6 +70,7 @@ pgvector = "*"
qdrant-client = "*"
vllm = "*"
sentence-transformers = "*"
+peft = "*"
[tool.poetry.group.lint.dependencies]
diff --git a/requirements.txt b/requirements.txt
index a445ebe6..9944b616 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,7 +40,7 @@ albumentations
basicsr
termcolor==2.2.0
controlnet-aux
-diffusers==0.17.1
+diffusers
einops==0.7.0
imageio==2.25.1
opencv-python-headless==4.8.1.78
@@ -74,3 +74,4 @@ pgvector
qdrant-client
vllm
sentence-transformers
+peft
diff --git a/scripts/auto_tests_docs/auto_docs.py b/scripts/auto_tests_docs/auto_docs.py
new file mode 100644
index 00000000..5df0f63d
--- /dev/null
+++ b/scripts/auto_tests_docs/auto_docs.py
@@ -0,0 +1,108 @@
+###### VERSION 2
+import inspect
+import os
+import threading
+from zeta import OpenAIChat
+from scripts.auto_tests_docs.docs import DOCUMENTATION_WRITER_SOP
+from zeta.nn.modules._activations import (
+ AccurateGELUActivation,
+ ClippedGELUActivation,
+ FastGELUActivation,
+ GELUActivation,
+ LaplaceActivation,
+ LinearActivation,
+ MishActivation,
+ NewGELUActivation,
+ PytorchGELUTanh,
+ QuickGELUActivation,
+ ReLUSquaredActivation,
+)
+from zeta.nn.modules.dense_connect import DenseBlock
+from zeta.nn.modules.dual_path_block import DualPathBlock
+from zeta.nn.modules.feedback_block import FeedbackBlock
+from zeta.nn.modules.highway_layer import HighwayLayer
+from zeta.nn.modules.multi_scale_block import MultiScaleBlock
+from zeta.nn.modules.recursive_block import RecursiveBlock
+from dotenv import load_dotenv
+
+load_dotenv()
+
+api_key = os.getenv("OPENAI_API_KEY")
+
+model = OpenAIChat(
+ model_name="gpt-4",
+ openai_api_key=api_key,
+ max_tokens=4000,
+)
+
+
+def process_documentation(cls):
+ """
+ Process the documentation for a given class using OpenAI model and save it in a Markdown file.
+ """
+ doc = inspect.getdoc(cls)
+ source = inspect.getsource(cls)
+ input_content = (
+ "Class Name:"
+ f" {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+ f" Code:\n{source}"
+ )
+ print(input_content)
+
+ # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content)
+ processed_content = model(
+ DOCUMENTATION_WRITER_SOP(input_content, "zeta")
+ )
+
+ doc_content = f"# {cls.__name__}\n\n{processed_content}\n"
+
+ # Create the directory if it doesn't exist
+ dir_path = "docs/zeta/nn/modules"
+ os.makedirs(dir_path, exist_ok=True)
+
+ # Write the processed documentation to a Markdown file
+ file_path = os.path.join(dir_path, f"{cls.__name__.lower()}.md")
+ with open(file_path, "w") as file:
+ file.write(doc_content)
+
+
+def main():
+ classes = [
+ DenseBlock,
+ HighwayLayer,
+ MultiScaleBlock,
+ FeedbackBlock,
+ DualPathBlock,
+ RecursiveBlock,
+ PytorchGELUTanh,
+ NewGELUActivation,
+ GELUActivation,
+ FastGELUActivation,
+ QuickGELUActivation,
+ ClippedGELUActivation,
+ AccurateGELUActivation,
+ MishActivation,
+ LinearActivation,
+ LaplaceActivation,
+ ReLUSquaredActivation,
+ ]
+
+ threads = []
+ for cls in classes:
+ thread = threading.Thread(
+ target=process_documentation, args=(cls,)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ print(
+ "Documentation generated in 'docs/zeta/nn/modules' directory."
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/auto_tests_docs/auto_docs_functions.py b/scripts/auto_tests_docs/auto_docs_functions.py
new file mode 100644
index 00000000..37bf376d
--- /dev/null
+++ b/scripts/auto_tests_docs/auto_docs_functions.py
@@ -0,0 +1,77 @@
+import inspect
+import os
+import sys
+import threading
+
+from dotenv import load_dotenv
+
+from scripts.auto_tests_docs.docs import DOCUMENTATION_WRITER_SOP
+from swarms import OpenAIChat
+
+load_dotenv()
+
+api_key = os.getenv("OPENAI_API_KEY")
+
+model = OpenAIChat(
+ model_name="gpt-4",
+ openai_api_key=api_key,
+ max_tokens=4000,
+)
+
+
+def process_documentation(item):
+ """
+ Process the documentation for a given function using OpenAI model and save it in a Markdown file.
+ """
+ doc = inspect.getdoc(item)
+ source = inspect.getsource(item)
+ input_content = (
+ f"Name: {item.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+ f" Code:\n{source}"
+ )
+ print(input_content)
+
+ # Process with OpenAI model
+ processed_content = model(
+ DOCUMENTATION_WRITER_SOP(input_content, "swarms.utils")
+ )
+
+ doc_content = f"# {item.__name__}\n\n{processed_content}\n"
+
+ # Create the directory if it doesn't exist
+ dir_path = "docs/swarms/utils"
+ os.makedirs(dir_path, exist_ok=True)
+
+ # Write the processed documentation to a Markdown file
+ file_path = os.path.join(dir_path, f"{item.__name__.lower()}.md")
+ with open(file_path, "w") as file:
+ file.write(doc_content)
+
+
+def main():
+ # Gathering all functions from the swarms.utils module
+ functions = [
+ obj
+ for name, obj in inspect.getmembers(
+ sys.modules["swarms.utils"]
+ )
+ if inspect.isfunction(obj)
+ ]
+
+ threads = []
+ for func in functions:
+ thread = threading.Thread(
+ target=process_documentation, args=(func,)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ print("Documentation generated in 'docs/swarms/utils' directory.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/auto_tests_docs/auto_tests.py b/scripts/auto_tests_docs/auto_tests.py
new file mode 100644
index 00000000..73a35c4f
--- /dev/null
+++ b/scripts/auto_tests_docs/auto_tests.py
@@ -0,0 +1,123 @@
+import inspect
+import os
+import re
+import threading
+from swarms import OpenAIChat
+from scripts.auto_tests_docs.docs import TEST_WRITER_SOP_PROMPT
+from zeta.nn.modules._activations import (
+ AccurateGELUActivation,
+ ClippedGELUActivation,
+ FastGELUActivation,
+ GELUActivation,
+ LaplaceActivation,
+ LinearActivation,
+ MishActivation,
+ NewGELUActivation,
+ PytorchGELUTanh,
+ QuickGELUActivation,
+ ReLUSquaredActivation,
+)
+from zeta.nn.modules.dense_connect import DenseBlock
+from zeta.nn.modules.dual_path_block import DualPathBlock
+from zeta.nn.modules.feedback_block import FeedbackBlock
+from zeta.nn.modules.highway_layer import HighwayLayer
+from zeta.nn.modules.multi_scale_block import MultiScaleBlock
+from zeta.nn.modules.recursive_block import RecursiveBlock
+from dotenv import load_dotenv
+
+load_dotenv()
+
+api_key = os.getenv("OPENAI_API_KEY")
+
+model = OpenAIChat(
+ model_name="gpt-4",
+ openai_api_key=api_key,
+ max_tokens=4000,
+)
+
+
+def extract_code_from_markdown(markdown_content: str):
+ """
+ Extracts code blocks from a Markdown string and returns them as a single string.
+
+ Args:
+ - markdown_content (str): The Markdown content as a string.
+
+ Returns:
+ - str: A single string containing all the code blocks separated by newlines.
+ """
+ # Regular expression for fenced code blocks
+ pattern = r"```(?:\w+\n)?(.*?)```"
+ matches = re.findall(pattern, markdown_content, re.DOTALL)
+
+ # Concatenate all code blocks separated by newlines
+ return "\n".join(code.strip() for code in matches)
+
+
+def create_test(cls):
+ """
+ Process the documentation for a given class using OpenAI model and save it in a Python file.
+ """
+ doc = inspect.getdoc(cls)
+ source = inspect.getsource(cls)
+ input_content = (
+ "Class Name:"
+ f" {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+ f" Code:\n{source}"
+ )
+ print(input_content)
+
+ # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content)
+ processed_content = model(
+ TEST_WRITER_SOP_PROMPT(input_content, "zeta", "zeta.nn")
+ )
+ processed_content = extract_code_from_markdown(processed_content)
+
+ doc_content = f"# {cls.__name__}\n\n{processed_content}\n"
+
+ # Create the directory if it doesn't exist
+ dir_path = "tests/nn/modules"
+ os.makedirs(dir_path, exist_ok=True)
+
+ # Write the processed documentation to a Python file
+ file_path = os.path.join(dir_path, f"{cls.__name__.lower()}.py")
+ with open(file_path, "w") as file:
+ file.write(doc_content)
+
+
+def main():
+ classes = [
+ DenseBlock,
+ HighwayLayer,
+ MultiScaleBlock,
+ FeedbackBlock,
+ DualPathBlock,
+ RecursiveBlock,
+ PytorchGELUTanh,
+ NewGELUActivation,
+ GELUActivation,
+ FastGELUActivation,
+ QuickGELUActivation,
+ ClippedGELUActivation,
+ AccurateGELUActivation,
+ MishActivation,
+ LinearActivation,
+ LaplaceActivation,
+ ReLUSquaredActivation,
+ ]
+
+ threads = []
+ for cls in classes:
+ thread = threading.Thread(target=create_test, args=(cls,))
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+    print("Tests generated in 'tests/nn/modules' directory.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/auto_tests_docs/auto_tests_functions.py b/scripts/auto_tests_docs/auto_tests_functions.py
new file mode 100644
index 00000000..437ff3bd
--- /dev/null
+++ b/scripts/auto_tests_docs/auto_tests_functions.py
@@ -0,0 +1,85 @@
+import inspect
+import os
+import sys
+import threading
+
+from dotenv import load_dotenv
+
+from scripts.auto_tests_docs.docs import TEST_WRITER_SOP_PROMPT
+from swarms import OpenAIChat
+from swarms.utils.parse_code import extract_code_from_markdown
+
+load_dotenv()
+
+api_key = os.getenv("OPENAI_API_KEY")
+
+model = OpenAIChat(
+ model_name="gpt-4",
+ openai_api_key=api_key,
+ max_tokens=4000,
+)
+
+
+def process_documentation(item):
+ """
+ Process the documentation for a given function using OpenAI model and save it in a Markdown file.
+ """
+ doc = inspect.getdoc(item)
+ source = inspect.getsource(item)
+ input_content = (
+ f"Name: {item.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+ f" Code:\n{source}"
+ )
+ # print(input_content)
+
+ # Process with OpenAI model
+ processed_content = model(
+ TEST_WRITER_SOP_PROMPT(
+ input_content, "swarms.utils", "swarms.utils"
+ )
+ )
+ processed_content = extract_code_from_markdown(processed_content)
+ print(processed_content)
+
+ doc_content = f"{processed_content}"
+
+ # Create the directory if it doesn't exist
+ dir_path = "tests/utils"
+ os.makedirs(dir_path, exist_ok=True)
+
+ # Write the processed documentation to a Markdown file
+ file_path = os.path.join(dir_path, f"{item.__name__.lower()}.py")
+ with open(file_path, "w") as file:
+ file.write(doc_content)
+
+
+def main():
+ # Gathering all functions from the swarms.utils module
+ functions = [
+ obj
+ for name, obj in inspect.getmembers(
+ sys.modules["swarms.utils"]
+ )
+ if inspect.isfunction(obj)
+ ]
+
+ threads = []
+ for func in functions:
+ thread = threading.Thread(
+ target=process_documentation, args=(func,)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ print("Tests generated in 'tests/utils' directory.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/auto_tests_docs/docs.py b/scripts/auto_tests_docs/docs.py
new file mode 100644
index 00000000..01df9d71
--- /dev/null
+++ b/scripts/auto_tests_docs/docs.py
@@ -0,0 +1,201 @@
+def DOCUMENTATION_WRITER_SOP(
+ task: str,
+ module: str,
+):
+ documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the {module} code below follow the outline for the {module} library,
+ provide many examples and teach the user about the code, provide examples for every function, make the documentation 10,000 words,
+ provide many usage examples and note this is markdown docs, create the documentation for the code to document,
+ put the arguments and methods in a table in markdown to make it visually seamless
+
+ Now make the professional documentation for this code, provide the architecture and how the class works and why it works that way,
+ it's purpose, provide args, their types, 3 ways of usage examples, in examples show all the code like imports main example etc
+
+ BE VERY EXPLICIT AND THOROUGH, MAKE IT DEEP AND USEFUL
+
+ ########
+ Step 1: Understand the purpose and functionality of the module or framework
+
+ Read and analyze the description provided in the documentation to understand the purpose and functionality of the module or framework.
+ Identify the key features, parameters, and operations performed by the module or framework.
+ Step 2: Provide an overview and introduction
+
+ Start the documentation by providing a brief overview and introduction to the module or framework.
+ Explain the importance and relevance of the module or framework in the context of the problem it solves.
+ Highlight any key concepts or terminology that will be used throughout the documentation.
+ Step 3: Provide a class or function definition
+
+ Provide the class or function definition for the module or framework.
+ Include the parameters that need to be passed to the class or function and provide a brief description of each parameter.
+ Specify the data types and default values for each parameter.
+ Step 4: Explain the functionality and usage
+
+ Provide a detailed explanation of how the module or framework works and what it does.
+ Describe the steps involved in using the module or framework, including any specific requirements or considerations.
+ Provide code examples to demonstrate the usage of the module or framework.
+ Explain the expected inputs and outputs for each operation or function.
+ Step 5: Provide additional information and tips
+
+ Provide any additional information or tips that may be useful for using the module or framework effectively.
+ Address any common issues or challenges that developers may encounter and provide recommendations or workarounds.
+ Step 6: Include references and resources
+
+ Include references to any external resources or research papers that provide further information or background on the module or framework.
+ Provide links to relevant documentation or websites for further exploration.
+ Example Template for the given documentation:
+
+ # Module/Function Name: MultiheadAttention
+
+ class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
+ ```
+ Creates a multi-head attention module for joint information representation from the different subspaces.
+
+ Parameters:
+ - embed_dim (int): Total dimension of the model.
+ - num_heads (int): Number of parallel attention heads. The embed_dim will be split across num_heads.
+ - dropout (float): Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
+ - bias (bool): If specified, adds bias to input/output projection layers. Default: True.
+ - add_bias_kv (bool): If specified, adds bias to the key and value sequences at dim=0. Default: False.
+ - add_zero_attn (bool): If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
+ - kdim (int): Total number of features for keys. Default: None (uses kdim=embed_dim).
+ - vdim (int): Total number of features for values. Default: None (uses vdim=embed_dim).
+ - batch_first (bool): If True, the input and output tensors are provided as (batch, seq, feature). Default: False.
+ - device (torch.device): If specified, the tensors will be moved to the specified device.
+ - dtype (torch.dtype): If specified, the tensors will have the specified dtype.
+ ```
+
+ def forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False):
+ ```
+ Forward pass of the multi-head attention module.
+
+ Parameters:
+ - query (Tensor): Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False, or (N, L, E_q) when batch_first=True.
+ - key (Tensor): Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False, or (N, S, E_k) when batch_first=True.
+ - value (Tensor): Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False, or (N, S, E_v) when batch_first=True.
+ - key_padding_mask (Optional[Tensor]): If specified, a mask indicating elements to be ignored in key for attention computation.
+ - need_weights (bool): If specified, returns attention weights in addition to attention outputs. Default: True.
+ - attn_mask (Optional[Tensor]): If specified, a mask preventing attention to certain positions.
+ - average_attn_weights (bool): If true, returns averaged attention weights per head. Otherwise, returns attention weights separately per head. Note that this flag only has an effect when need_weights=True. Default: True.
+ - is_causal (bool): If specified, applies a causal mask as the attention mask. Default: False.
+
+ Returns:
+ Tuple[Tensor, Optional[Tensor]]:
+ - attn_output (Tensor): Attention outputs of shape (L, E) for unbatched input, (L, N, E) when batch_first=False, or (N, L, E) when batch_first=True.
+ - attn_output_weights (Optional[Tensor]): Attention weights of shape (L, S) when unbatched or (N, L, S) when batched. Optional, only returned when need_weights=True.
+ ```
+
+ # Implementation of the forward pass of the attention module goes here
+
+ return attn_output, attn_output_weights
+
+ ```
+ # Usage example:
+
+ multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ attn_output, attn_output_weights = multihead_attn(query, key, value)
+ Note:
+
+ The above template includes the class or function definition, parameters, description, and usage example.
+ To replicate the documentation for any other module or framework, follow the same structure and provide the specific details for that module or framework.
+
+
+ ############# DOCUMENT THE FOLLOWING CODE ########
+ {task}
+ """
+ return documentation
+
+
+def TEST_WRITER_SOP_PROMPT(
+ task: str, module: str, path: str, *args, **kwargs
+):
+ TESTS_PROMPT = f"""
+
+ Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any
+ just write the best tests possible, the module is {module}, the file path is {path}
+
+
+ ######### TESTING GUIDE #############
+
+ # **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`**
+
+ 1. **Preparation**:
+ - Install pytest: `pip install pytest`.
+ - Structure your project so that tests are in a separate `tests/` directory.
+ - Name your test files with the prefix `test_` for pytest to recognize them.
+
+ 2. **Writing Basic Tests**:
+ - Use clear function names prefixed with `test_` (e.g., `test_check_value()`).
+ - Use assert statements to validate results.
+
+ 3. **Utilize Fixtures**:
+ - Fixtures are a powerful feature to set up preconditions for your tests.
+ - Use `@pytest.fixture` decorator to define a fixture.
+ - Pass fixture name as an argument to your test to use it.
+
+ 4. **Parameterized Testing**:
+ - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs.
+ - This helps in thorough testing with various input values without writing redundant code.
+
+ 5. **Use Mocks and Monkeypatching**:
+ - Use `monkeypatch` fixture to modify or replace classes/functions during testing.
+ - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code.
+
+ 6. **Exception Testing**:
+ - Test for expected exceptions using `pytest.raises(ExceptionType)`.
+
+ 7. **Test Coverage**:
+ - Install pytest-cov: `pip install pytest-cov`.
+ - Run tests with `pytest --cov=my_module` to get a coverage report.
+
+ 8. **Environment Variables and Secret Handling**:
+ - Store secrets and configurations in environment variables.
+ - Use libraries like `python-decouple` or `python-dotenv` to load environment variables.
+ - For tests, mock or set environment variables temporarily within the test environment.
+
+ 9. **Grouping and Marking Tests**:
+ - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`).
+ - This allows for selectively running certain groups of tests.
+
+ 10. **Use Plugins**:
+ - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs.
+
+ 11. **Continuous Integration (CI)**:
+ - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions.
+ - Ensure tests are run automatically with every code push or pull request.
+
+ 12. **Logging and Reporting**:
+ - Use `pytest`'s inbuilt logging.
+ - Integrate with tools like `Allure` for more comprehensive reporting.
+
+ 13. **Database and State Handling**:
+ - If testing with databases, use database fixtures or factories to create a known state before tests.
+ - Clean up and reset state post-tests to maintain consistency.
+
+ 14. **Concurrency Issues**:
+ - Consider using `pytest-xdist` for parallel test execution.
+ - Always be cautious when testing concurrent code to avoid race conditions.
+
+ 15. **Clean Code Practices**:
+ - Ensure tests are readable and maintainable.
+ - Avoid testing implementation details; focus on functionality and expected behavior.
+
+ 16. **Regular Maintenance**:
+ - Periodically review and update tests.
+ - Ensure that tests stay relevant as your codebase grows and changes.
+
+ 17. **Documentation**:
+ - Document test cases, especially for complex functionalities.
+ - Ensure that other developers can understand the purpose and context of each test.
+
+ 18. **Feedback Loop**:
+ - Use test failures as feedback for development.
+ - Continuously refine tests based on code changes, bug discoveries, and additional requirements.
+
+ By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project.
+
+
+ ######### CREATE TESTS FOR THIS CODE: #######
+ {task}
+
+ """
+
+ return TESTS_PROMPT
diff --git a/scripts/auto_tests_docs/mkdocs_handler.py b/scripts/auto_tests_docs/mkdocs_handler.py
new file mode 100644
index 00000000..6cb0452b
--- /dev/null
+++ b/scripts/auto_tests_docs/mkdocs_handler.py
@@ -0,0 +1,31 @@
+import os
+
+
+def generate_file_list(directory, output_file):
+ """
+ Generate a list of files in a directory in the specified format and write it to a file.
+
+ Args:
+ directory (str): The directory to list the files from.
+ output_file (str): The file to write the output to.
+ """
+ with open(output_file, "w") as f:
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".md"):
+ # Remove the directory from the file path and replace slashes with dots
+ file_path = (
+ os.path.join(root, file)
+ .replace(directory + "/", "")
+ .replace("/", ".")
+ )
+ # Remove the file extension
+ file_name, _ = os.path.splitext(file)
+ # Write the file name and path to the output file
+ f.write(
+ f'- {file_name}: "swarms/utils/{file_path}"\n'
+ )
+
+
+# Use the function to generate the file list
+generate_file_list("docs/swarms/utils", "file_list.txt")
diff --git a/scripts/auto_tests_docs/update_mkdocs.py b/scripts/auto_tests_docs/update_mkdocs.py
new file mode 100644
index 00000000..dfde53cb
--- /dev/null
+++ b/scripts/auto_tests_docs/update_mkdocs.py
@@ -0,0 +1,64 @@
+import yaml
+
+
+def update_mkdocs(
+ class_names,
+ base_path="docs/zeta/nn/modules",
+ mkdocs_file="mkdocs.yml",
+):
+ """
+ Update the mkdocs.yml file with new documentation links.
+
+ Args:
+ - class_names: A list of class names for which documentation is generated.
+ - base_path: The base path where documentation Markdown files are stored.
+ - mkdocs_file: The path to the mkdocs.yml file.
+ """
+ with open(mkdocs_file, "r") as file:
+ mkdocs_config = yaml.safe_load(file)
+
+ # Find or create the 'zeta.nn.modules' section in 'nav'
+ zeta_modules_section = None
+ for section in mkdocs_config.get("nav", []):
+ if "zeta.nn.modules" in section:
+ zeta_modules_section = section["zeta.nn.modules"]
+ break
+
+ if zeta_modules_section is None:
+ zeta_modules_section = {}
+ mkdocs_config["nav"].append(
+ {"zeta.nn.modules": zeta_modules_section}
+ )
+
+ # Add the documentation paths to the 'zeta.nn.modules' section
+ for class_name in class_names:
+ doc_path = f"{base_path}/{class_name.lower()}.md"
+ zeta_modules_section[class_name] = doc_path
+
+ # Write the updated content back to mkdocs.yml
+ with open(mkdocs_file, "w") as file:
+ yaml.safe_dump(mkdocs_config, file, sort_keys=False)
+
+
+# Example usage
+classes = [
+ "DenseBlock",
+ "HighwayLayer",
+ "MultiScaleBlock",
+ "FeedbackBlock",
+ "DualPathBlock",
+ "RecursiveBlock",
+ "PytorchGELUTanh",
+ "NewGELUActivation",
+ "GELUActivation",
+ "FastGELUActivation",
+ "QuickGELUActivation",
+ "ClippedGELUActivation",
+ "AccurateGELUActivation",
+ "MishActivation",
+ "LinearActivation",
+ "LaplaceActivation",
+ "ReLUSquaredActivation",
+]
+
+update_mkdocs(classes)
diff --git a/scripts/code_quality.sh b/scripts/code_quality.sh
index 90153258..e3afec13 100755
--- a/scripts/code_quality.sh
+++ b/scripts/code_quality.sh
@@ -1,19 +1,19 @@
#!/bin/bash
-# Navigate to the directory containing the 'swarms' folder
+# Navigate to the directory containing the 'tests' folder
# cd /path/to/your/code/directory
# Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i)
-# on all Python files (*.py) under the 'swarms' directory.
-autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes swarms/
+# on all Python files (*.py) under the 'zeta' directory.
+autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes zeta/
# Run black with default settings, since black does not have an aggressiveness level.
-# Black will format all Python files it finds in the 'swarms' directory.
-black --experimental-string-processing swarms/
+# Black will format all Python files it finds in the 'zeta' directory.
+black --experimental-string-processing zeta/
-# Run ruff on the 'swarms' directory.
+# Run ruff on the 'zeta' directory.
# Add any additional flags if needed according to your version of ruff.
-ruff --unsafe_fix
+ruff zeta/ --fix
# YAPF
-yapf --recursive --in-place --verbose --style=google --parallel swarms
+yapf --recursive --in-place --verbose --style=google --parallel tests
diff --git a/scripts/del_pycache.sh b/scripts/del_pycache.sh
new file mode 100755
index 00000000..dc8e7d2b
--- /dev/null
+++ b/scripts/del_pycache.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+# Find and delete all __pycache__ directories
+find . -type d -name "__pycache__" -exec rm -r {} +
+
+# Find and delete all .pyc files
+find . -type f -name "*.pyc" -delete
\ No newline at end of file
diff --git a/scripts/delete_pycache.sh b/scripts/delete_pycache.sh
deleted file mode 100644
index db11f239..00000000
--- a/scripts/delete_pycache.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-# Find all __pycache__ directories and delete them
-find . -type d -name "__pycache__" -exec rm -rf {} +
\ No newline at end of file
diff --git a/scripts/test_name.sh b/scripts/test_name.sh
index cdc6a013..4123f870 100755
--- a/scripts/test_name.sh
+++ b/scripts/test_name.sh
@@ -4,5 +4,6 @@ do
dir=$(dirname "$file")
if [[ $filename != test_* ]]; then
mv "$file" "$dir/test_$filename"
+ printf "\e[1;34mRenamed: \e[0m$file \e[1;32mto\e[0m $dir/test_$filename\n"
fi
done
\ No newline at end of file
diff --git a/swarm_network.py b/swarm_network.py
new file mode 100644
index 00000000..de9c53b6
--- /dev/null
+++ b/swarm_network.py
@@ -0,0 +1,46 @@
+import os
+
+from dotenv import load_dotenv
+
+# Import the OpenAIChat model and the Agent struct
+from swarms import OpenAIChat, Agent, SwarmNetwork
+
+# Load the environment variables
+load_dotenv()
+
+# Get the API key from the environment
+api_key = os.environ.get("OPENAI_API_KEY")
+
+# Initialize the language model
+llm = OpenAIChat(
+ temperature=0.5,
+ openai_api_key=api_key,
+)
+
+## Initialize the workflow
+agent = Agent(llm=llm, max_loops=1, agent_name="Social Media Manager")
+agent2 = Agent(llm=llm, max_loops=1, agent_name="Product Manager")
+agent3 = Agent(llm=llm, max_loops=1, agent_name="SEO Manager")
+
+
+# Load the swarmnet with the agents
+swarmnet = SwarmNetwork(
+ agents=[agent, agent2, agent3],
+)
+
+# List the agents in the swarm network
+out = swarmnet.list_agents()
+print(out)
+
+# Run the workflow on a task
+out = swarmnet.run_single_agent(
+ agent2.id, "Generate a 10,000 word blog on health and wellness."
+)
+print(out)
+
+
+# Run all the agents in the swarm network on a task
+out = swarmnet.run_many_agents(
+ "Generate a 10,000 word blog on health and wellness."
+)
+print(out)
diff --git a/swarms/__init__.py b/swarms/__init__.py
index f6f04205..d555cf4d 100644
--- a/swarms/__init__.py
+++ b/swarms/__init__.py
@@ -1,6 +1,4 @@
-from swarms.utils.disable_logging import disable_logging
-
-disable_logging()
+# disable_logging()
from swarms.agents import * # noqa: E402, F403
from swarms.swarms import * # noqa: E402, F403
diff --git a/swarms/agents/simple_agent.py b/swarms/agents/simple_agent.py
new file mode 100644
index 00000000..1c6d3126
--- /dev/null
+++ b/swarms/agents/simple_agent.py
@@ -0,0 +1,40 @@
+from swarms.structs.conversation import Conversation
+from swarms.models.base_llm import AbstractLLM
+
+
+# Run the language model in a loop for n iterations
+def SimpleAgent(
+ llm: AbstractLLM = None, iters: int = 10, *args, **kwargs
+):
+ """Simple agent conversation
+
+ Args:
+ llm (_type_): _description_
+ iters (int, optional): _description_. Defaults to 10.
+ """
+ try:
+ conv = Conversation(*args, **kwargs)
+ for i in range(iters):
+ user_input = input("User: ")
+ conv.add("user", user_input)
+ if user_input.lower() == "quit":
+ break
+ task = (
+ conv.return_history_as_string()
+ ) # Get the conversation history
+ out = llm(task)
+ conv.add("assistant", out)
+ print(
+ f"Assistant: {out}",
+ )
+ conv.display_conversation()
+ conv.export_conversation("conversation.txt")
+
+ except Exception as error:
+ print(f"[ERROR][SimpleAgentConversation] {error}")
+ raise error
+
+ except KeyboardInterrupt:
+ print("[INFO][SimpleAgentConversation] Keyboard interrupt")
+ conv.export_conversation("conversation.txt")
+ raise KeyboardInterrupt
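
A minimal interactive sketch of the new `SimpleAgent` helper. The model choice and key handling below mirror the repo's other examples and are assumptions, not requirements of this file:

```python
import os

from dotenv import load_dotenv

from swarms.agents.simple_agent import SimpleAgent
from swarms.models import OpenAIChat

load_dotenv()

# Any AbstractLLM works here; OpenAIChat is simply the most common choice.
llm = OpenAIChat(openai_api_key=os.environ.get("OPENAI_API_KEY"))

# Blocks on stdin for up to 3 turns; typing "quit" ends the session early,
# and the transcript is exported to conversation.txt when the loop finishes.
SimpleAgent(llm, iters=3)
```
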
diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py
index 40f9979c..0a553a16 100644
--- a/swarms/memory/qdrant.py
+++ b/swarms/memory/qdrant.py
@@ -1,4 +1,3 @@
-import subprocess
from typing import List
from httpx import RequestError
@@ -8,9 +7,6 @@ try:
except ImportError:
print("Please install the sentence-transformers package")
print("pip install sentence-transformers")
- print("pip install qdrant-client")
- subprocess.run(["pip", "install", "sentence-transformers"])
-
try:
from qdrant_client import QdrantClient
@@ -22,7 +18,6 @@ try:
except ImportError:
print("Please install the qdrant-client package")
print("pip install qdrant-client")
- subprocess.run(["pip", "install", "qdrant-client"])
class Qdrant:
diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index b66eb1d3..58701f64 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -9,16 +9,17 @@ from swarms.models.openai_models import (
OpenAIChat,
) # noqa: E402
-# from swarms.models.vllm import vLLM # noqa: E402
-# from swarms.models.zephyr import Zephyr # noqa: E402
+from swarms.models.vllm import vLLM # noqa: E402
+from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
from swarms.models.wizard_storytelling import (
WizardLLMStoryTeller,
) # noqa: E402
from swarms.models.mpt import MPT7B # noqa: E402
+from swarms.models.mixtral import Mixtral # noqa: E402
-# MultiModal Models
+################# MultiModal Models
from swarms.models.base_multimodal_model import (
BaseMultiModalModel,
) # noqa: E402
@@ -32,6 +33,7 @@ from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402
from swarms.models.openai_tts import OpenAITTS # noqa: E402
from swarms.models.gemini import Gemini # noqa: E402
from swarms.models.gigabind import Gigabind # noqa: E402
+from swarms.models.zeroscope import ZeroscopeTTV # noqa: E402
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
@@ -39,6 +41,14 @@ from swarms.models.gigabind import Gigabind # noqa: E402
# from swarms.models.whisperx_model import WhisperX # noqa: E402
# from swarms.models.kosmos_two import Kosmos # noqa: E402
+from swarms.models.types import (
+ TextModality,
+ ImageModality,
+ AudioModality,
+ VideoModality,
+ MultimodalData,
+) # noqa: E402
+
__all__ = [
"AbstractLLM",
"Anthropic",
@@ -47,7 +57,7 @@ __all__ = [
"OpenAI",
"AzureOpenAI",
"OpenAIChat",
- # "Zephyr",
+ "Zephyr",
"BaseMultiModalModel",
"Idefics",
# "Kosmos",
@@ -62,8 +72,15 @@ __all__ = [
# "Dalle3",
# "DistilWhisperModel",
"GPT4VisionAPI",
- # "vLLM",
+ "vLLM",
"OpenAITTS",
"Gemini",
"Gigabind",
+ "Mixtral",
+ "ZeroscopeTTV",
+ "TextModality",
+ "ImageModality",
+ "AudioModality",
+ "VideoModality",
+ "MultimodalData",
]
diff --git a/swarms/models/base_ttv.py b/swarms/models/base_ttv.py
new file mode 100644
index 00000000..6ef959e8
--- /dev/null
+++ b/swarms/models/base_ttv.py
@@ -0,0 +1,115 @@
+from abc import abstractmethod
+from swarms.models.base_llm import AbstractLLM
+from diffusers.utils import export_to_video
+from typing import Optional, List
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+
+class BaseTextToVideo(AbstractLLM):
+ """BaseTextToVideo class represents prebuilt text-to-video models."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @abstractmethod
+ def run(self, *args, **kwargs):
+ pass
+
+ def __call__(
+ self,
+ task: Optional[str] = None,
+ img: Optional[str] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Performs forward pass on the input task and returns the path of the generated video.
+
+ Args:
+ task (str): The task to perform.
+
+ Returns:
+ str: The path of the generated video.
+ """
+ return self.run(task, img, *args, **kwargs)
+
+ def save_video_path(
+ self, video_path: Optional[str] = None, *args, **kwargs
+ ):
+ """Saves the generated video to the specified path.
+
+ Args:
+            video_path (Optional[str], optional): Path to write the video to. Defaults to None.
+
+ Returns:
+ str: The path of the generated video.
+ """
+ return export_to_video(video_path, *args, **kwargs)
+
+ def run_batched(
+ self,
+ tasks: List[str] = None,
+ imgs: List[str] = None,
+ *args,
+ **kwargs,
+ ):
+ # TODO: Implement batched inference
+ tasks = tasks or []
+ imgs = imgs or []
+ if len(tasks) != len(imgs):
+ raise ValueError(
+ "The number of tasks and images should be the same."
+ )
+ return [
+ self.run(task, img, *args, **kwargs)
+ for task, img in zip(tasks, imgs)
+ ]
+
+ def run_concurrent_batched(
+ self,
+ tasks: List[str] = None,
+ imgs: List[str] = None,
+ *args,
+ **kwargs,
+ ):
+ tasks = tasks or []
+ imgs = imgs or []
+ if len(tasks) != len(imgs):
+ raise ValueError(
+ "The number of tasks and images should be the same."
+ )
+ with ThreadPoolExecutor(max_workers=4) as executor:
+ loop = asyncio.get_event_loop()
+ tasks = [
+ loop.run_in_executor(
+ executor, self.run, task, img, *args, **kwargs
+ )
+ for task, img in zip(tasks, imgs)
+ ]
+ return loop.run_until_complete(asyncio.gather(*tasks))
+
+    # Run the model in async mode
+    def arun(
+        self,
+        task: Optional[str] = None,
+        img: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        loop = asyncio.get_event_loop()
+        # `run` is synchronous, so hand it to the default executor instead
+        # of awaiting it directly.
+        return loop.run_until_complete(
+            loop.run_in_executor(
+                None, lambda: self.run(task, img, *args, **kwargs)
+            )
+        )
+
+    def arun_batched(
+        self,
+        tasks: List[str] = None,
+        imgs: List[str] = None,
+        *args,
+        **kwargs,
+    ):
+        loop = asyncio.get_event_loop()
+        # `run_batched` is synchronous as well, so wrap it the same way.
+        return loop.run_until_complete(
+            loop.run_in_executor(
+                None,
+                lambda: self.run_batched(tasks, imgs, *args, **kwargs),
+            )
+        )
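
`BaseTextToVideo` only requires subclasses to implement `run`; calling, batching, and the async wrappers are inherited. A toy subclass sketch, assuming `AbstractLLM` can be constructed without arguments:

```python
from swarms.models.base_ttv import BaseTextToVideo


class FakeTTV(BaseTextToVideo):
    """Toy text-to-video model that fabricates a file path instead of rendering."""

    def run(self, task: str = None, img: str = None, *args, **kwargs):
        # A real model would run a diffusion pipeline here and export the frames.
        return f"/tmp/{(task or 'video').replace(' ', '_')}.mp4"


ttv = FakeTTV()
print(ttv("a dog surfing a wave"))                    # __call__ delegates to run
print(ttv.run_batched(["a cat", "a dog"], ["", ""]))  # one result per task
```
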
diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py
index 40f63418..6b225b49 100644
--- a/swarms/models/dalle3.py
+++ b/swarms/models/dalle3.py
@@ -18,8 +18,6 @@ from termcolor import colored
load_dotenv()
-# api_key = os.getenv("OPENAI_API_KEY")
-
# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
diff --git a/swarms/models/diffusers_general.py b/swarms/models/diffusers_general.py
new file mode 100644
index 00000000..9d7ea250
--- /dev/null
+++ b/swarms/models/diffusers_general.py
@@ -0,0 +1 @@
+# Base implementation for the diffusers library
diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py
deleted file mode 100644
index e97fb496..00000000
--- a/swarms/models/fastvit.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import json
-import os
-from typing import List
-
-import timm
-import torch
-from PIL import Image
-from pydantic import BaseModel, StrictFloat, StrictInt, validator
-
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Load the classes for image classification
-with open(
- os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")
-) as f:
- FASTVIT_IMAGENET_1K_CLASSES = json.load(f)
-
-
-class ClassificationResult(BaseModel):
- class_id: List[StrictInt]
- confidence: List[StrictFloat]
-
- @validator("class_id", "confidence", pre=True, each_item=True)
- def check_list_contents(cls, v):
- assert isinstance(v, int) or isinstance(
- v, float
- ), "must be integer or float"
- return v
-
-
-class FastViT:
- """
- FastViT model for image classification
-
- Args:
- img (str): path to the input image
- confidence_threshold (float): confidence threshold for the model's predictions
-
- Returns:
- ClassificationResult: a pydantic BaseModel containing the class ids and confidences of the model's predictions
-
- Example:
- >>> fastvit = FastViT()
- >>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)
-
- To use, create a json file called: fast_vit_classes.json
- """
-
- def __init__(self):
- self.model = timm.create_model(
- "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True
- ).to(DEVICE)
- data_config = timm.data.resolve_model_data_config(self.model)
- self.transforms = timm.data.create_transform(
- **data_config, is_training=False
- )
- self.model.eval()
-
- def __call__(
- self, img: str, confidence_threshold: float = 0.5
- ) -> ClassificationResult:
- """Classifies the input image and returns the top k classes and their probabilities"""
- img = Image.open(img).convert("RGB")
- img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
- with torch.no_grad():
- output = self.model(img_tensor)
- probabilities = torch.nn.functional.softmax(output, dim=1)
-
- # Get top k classes and their probabilities
- top_probs, top_classes = torch.topk(
- probabilities, k=FASTVIT_IMAGENET_1K_CLASSES
- )
-
- # Filter by confidence threshold
- mask = top_probs > confidence_threshold
- top_probs, top_classes = top_probs[mask], top_classes[mask]
-
- # Convert to Python lists and map class indices to labels if needed
- top_probs = top_probs.cpu().numpy().tolist()
- top_classes = top_classes.cpu().numpy().tolist()
-
- return ClassificationResult(
- class_id=top_classes, confidence=top_probs
- )
diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py
index f722bbb6..e02e53a5 100644
--- a/swarms/models/fuyu.py
+++ b/swarms/models/fuyu.py
@@ -49,11 +49,11 @@ class Fuyu(BaseMultiModalModel):
self.processor = FuyuProcessor(
image_processor=self.image_processor,
tokenizer=self.tokenizer,
- **kwargs,
)
self.model = FuyuForCausalLM.from_pretrained(
model_name,
device_map=device_map,
+ *args,
**kwargs,
)
@@ -62,7 +62,7 @@ class Fuyu(BaseMultiModalModel):
image_pil = Image.open(img)
return image_pil
- def run(self, text: str, img: str, *args, **kwargs):
+ def run(self, text: str = None, img: str = None, *args, **kwargs):
"""Run the pipeline
Args:
@@ -78,8 +78,6 @@ class Fuyu(BaseMultiModalModel):
text=text,
images=[img],
device=self.device_map,
- *args,
- **kwargs,
)
for k, v in model_inputs.items():
@@ -94,8 +92,6 @@ class Fuyu(BaseMultiModalModel):
text = self.processor.batch_decode(
output[:, -7:],
skip_special_tokens=True,
- *args,
- **kwargs,
)
return print(str(text))
except Exception as error:
diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py
index a0c5a86a..6bc4d810 100644
--- a/swarms/models/kosmos_two.py
+++ b/swarms/models/kosmos_two.py
@@ -8,7 +8,7 @@ import torchvision.transforms as T
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
-from swarms.models.base_multimodal_model import BaseMultimodalModel
+from swarms.models.base_multimodal_model import BaseMultiModalModel
# utils
@@ -18,7 +18,7 @@ def is_overlapping(rect1, rect2):
return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
-class Kosmos(BaseMultimodalModel):
+class Kosmos(BaseMultiModalModel):
"""
Kosmos model by Yen-Chun Shieh
diff --git a/swarms/models/multion.py b/swarms/models/multion.py
deleted file mode 100644
index 14152faf..00000000
--- a/swarms/models/multion.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from swarms.models.base_llm import AbstractLLM
-
-
-try:
- import multion
-
-except ImportError:
- raise ImportError(
- "Cannot import multion, please install 'pip install'"
- )
-
-
-class MultiOn(AbstractLLM):
- """
- MultiOn is a wrapper for the Multion API.
-
- Args:
- **kwargs:
-
- Methods:
- run(self, task: str, url: str, *args, **kwargs)
-
- Example:
- >>> from swarms.models.multion import MultiOn
- >>> multion = MultiOn()
- >>> multion.run("Order chicken tendies", "https://www.google.com/")
- "Order chicken tendies. https://www.google.com/"
-
- """
-
- def __init__(self, **kwargs):
- super(MultiOn, self).__init__(**kwargs)
-
- def run(self, task: str, url: str, *args, **kwargs) -> str:
- """Run the multion model
-
- Args:
- task (str): _description_
- url (str): _description_
-
- Returns:
- str: _description_
- """
- response = multion.new_session({"input": task, "url": url})
- return response
-
- def generate_summary(
- self, task: str, url: str, *args, **kwargs
- ) -> str:
- """Generate a summary from the multion model
-
- Args:
- task (str): _description_
- url (str): _description_
-
- Returns:
- str: _description_
- """
- response = multion.new_session({"input": task, "url": url})
- return response
diff --git a/swarms/models/open_dalle.py b/swarms/models/open_dalle.py
new file mode 100644
index 00000000..b43d6c2e
--- /dev/null
+++ b/swarms/models/open_dalle.py
@@ -0,0 +1,66 @@
+from typing import Optional, Any
+
+import torch
+from diffusers import AutoPipelineForText2Image
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+class OpenDalle(BaseMultiModalModel):
+ """OpenDalle model class
+
+ Attributes:
+ model_name (str): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1".
+ torch_dtype (torch.dtype): The torch data type to be used. Defaults to torch.float16.
+ device (str): The device to be used for computation. Defaults to "cuda".
+
+ Examples:
+ >>> from swarms.models.open_dalle import OpenDalle
+ >>> od = OpenDalle()
+ >>> od.run("A picture of a cat")
+
+ """
+
+ def __init__(
+ self,
+ model_name: str = "dataautogpt3/OpenDalleV1.1",
+ torch_dtype: Any = torch.float16,
+ device: str = "cuda",
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the OpenDalle model.
+
+ Args:
+ model_name (str, optional): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1".
+ torch_dtype (torch.dtype, optional): The torch data type to be used. Defaults to torch.float16.
+ device (str, optional): The device to be used for computation. Defaults to "cuda".
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+ """
+ self.pipeline = AutoPipelineForText2Image.from_pretrained(
+ model_name, torch_dtype=torch_dtype, *args, **kwargs
+ ).to(device)
+
+ def run(self, task: Optional[str] = None, *args, **kwargs):
+ """Run the OpenDalle model
+
+ Args:
+ task (str, optional): The task to be performed. Defaults to None.
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+            PIL.Image.Image: The generated image.
+ """
+ try:
+ if task is None:
+ raise ValueError("Task cannot be None")
+ if not isinstance(task, str):
+ raise TypeError("Task must be a string")
+ if len(task) < 1:
+ raise ValueError("Task cannot be empty")
+ return self.pipeline(task, *args, **kwargs).images[0]
+ except Exception as error:
+ print(f"[ERROR][OpenDalle] {error}")
+ raise error
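
A usage sketch for `OpenDalle`. The first call downloads the dataautogpt3/OpenDalleV1.1 weights and assumes a CUDA device (pass `device="cpu"` otherwise, at a large speed cost):

```python
from swarms.models.open_dalle import OpenDalle

model = OpenDalle()

# run() returns the first PIL image produced by the underlying diffusers pipeline.
image = model.run("A watercolor painting of a lighthouse at dusk")
image.save("lighthouse.png")
```
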
diff --git a/swarms/models/sam.py b/swarms/models/sam.py
index 866c79ee..110d80b7 100644
--- a/swarms/models/sam.py
+++ b/swarms/models/sam.py
@@ -1,315 +1,107 @@
-import cv2
-import numpy as np
+import torch
from PIL import Image
-from transformers import (
- SamImageProcessor,
- SamModel,
- SamProcessor,
- pipeline,
-)
+import requests
+from transformers import SamModel, SamProcessor
+from typing import List
-try:
- import cv2
- import supervision as sv
-except ImportError:
- print("Please install supervision and cv")
+device = "cuda" if torch.cuda.is_available() else "cpu"
-from enum import Enum
-
-
-class FeatureType(Enum):
- """
- An enumeration to represent the types of features for mask adjustment in image
- segmentation.
- """
-
- ISLAND = "ISLAND"
- HOLE = "HOLE"
-
- @classmethod
- def list(cls):
- return list(map(lambda c: c.value, cls))
-
-
-def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray:
- """
- Vectorized computation of the Intersection over Union (IoU) for all pairs of masks.
-
- Parameters:
- masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
- number of masks, `H` is the height, and `W` is the width.
-
- Returns:
- np.ndarray: A 2D numpy array of shape `(N, N)` where each element `[i, j]` is
- the IoU between masks `i` and `j`.
-
- Raises:
- ValueError: If any of the masks is found to be empty.
- """
- if np.any(masks.sum(axis=(1, 2)) == 0):
- raise ValueError(
- "One or more masks are empty. Please filter out empty"
- " masks before using `compute_iou_vectorized` function."
- )
-
- masks_bool = masks.astype(bool)
- masks_flat = masks_bool.reshape(masks.shape[0], -1)
- intersection = np.logical_and(
- masks_flat[:, None], masks_flat[None, :]
- ).sum(axis=2)
- union = np.logical_or(
- masks_flat[:, None], masks_flat[None, :]
- ).sum(axis=2)
- iou_matrix = intersection / union
- return iou_matrix
-
-
-def mask_non_max_suppression(
- masks: np.ndarray, iou_threshold: float = 0.6
-) -> np.ndarray:
+class SAM:
"""
- Performs Non-Max Suppression on a set of masks by prioritizing larger masks and
- removing smaller masks that overlap significantly.
+ Class representing the SAM (Segmentation and Masking) model.
- When the IoU between two masks exceeds the specified threshold, the smaller mask
- (in terms of area) is discarded. This process is repeated for each pair of masks,
- effectively filtering out masks that are significantly overlapped by larger ones.
+ Args:
+ model_name (str): The name of the pre-trained SAM model. Default is "facebook/sam-vit-huge".
+ device (torch.device): The device to run the model on. Default is the current device.
+ input_points (List[List[int]]): The 2D location of a window in the image to segment. Default is [[450, 600]].
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
- Parameters:
- masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
- number of masks, `H` is the height, and `W` is the width.
- iou_threshold (float): The IoU threshold for determining significant overlap.
+ Attributes:
+ model_name (str): The name of the pre-trained SAM model.
+ device (torch.device): The device to run the model on.
+ input_points (List[List[int]]): The 2D location of a window in the image to segment.
+ model (SamModel): The pre-trained SAM model.
+ processor (SamProcessor): The processor for the SAM model.
- Returns:
- np.ndarray: A 3D numpy array of filtered masks.
- """
- num_masks = masks.shape[0]
- areas = masks.sum(axis=(1, 2))
- sorted_idx = np.argsort(-areas)
- keep_mask = np.ones(num_masks, dtype=bool)
- iou_matrix = compute_mask_iou_vectorized(masks)
- for i in range(num_masks):
- if not keep_mask[sorted_idx[i]]:
- continue
-
- overlapping_masks = iou_matrix[sorted_idx[i]] > iou_threshold
- overlapping_masks[sorted_idx[i]] = False
- keep_mask[sorted_idx] = np.logical_and(
- keep_mask[sorted_idx], ~overlapping_masks
- )
-
- return masks[keep_mask]
-
-
-def filter_masks_by_relative_area(
- masks: np.ndarray,
- minimum_area: float = 0.01,
- maximum_area: float = 1.0,
-) -> np.ndarray:
- """
- Filters masks based on their relative area within the total area of each mask.
-
- Parameters:
- masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
- number of masks, `H` is the height, and `W` is the width.
- minimum_area (float): The minimum relative area threshold. Must be between `0`
- and `1`.
- maximum_area (float): The maximum relative area threshold. Must be between `0`
- and `1`.
-
- Returns:
- np.ndarray: A 3D numpy array containing masks that fall within the specified
- relative area range.
+ Methods:
+ run(task=None, img=None, *args, **kwargs): Runs the SAM model on the given image and returns the segmentation scores and masks.
+ process_img(img: str = None, *args, **kwargs): Processes the input image and returns the processed image.
- Raises:
- ValueError: If `minimum_area` or `maximum_area` are outside the `0` to `1`
- range, or if `minimum_area` is greater than `maximum_area`.
"""
- if not (isinstance(masks, np.ndarray) and masks.ndim == 3):
- raise ValueError("Input must be a 3D numpy array.")
-
- if not (0 <= minimum_area <= 1) or not (0 <= maximum_area <= 1):
- raise ValueError(
- "`minimum_area` and `maximum_area` must be between 0"
- " and 1."
- )
-
- if minimum_area > maximum_area:
- raise ValueError(
- "`minimum_area` must be less than or equal to"
- " `maximum_area`."
- )
-
- total_area = masks.shape[1] * masks.shape[2]
- relative_areas = masks.sum(axis=(1, 2)) / total_area
- return masks[
- (relative_areas >= minimum_area)
- & (relative_areas <= maximum_area)
- ]
-
-
-def adjust_mask_features_by_relative_area(
- mask: np.ndarray,
- area_threshold: float,
- feature_type: FeatureType = FeatureType.ISLAND,
-) -> np.ndarray:
- """
- Adjusts a mask by removing small islands or filling small holes based on a relative
- area threshold.
-
- !!! warning
-
- Running this function on a mask with small islands may result in empty masks.
-
- Parameters:
- mask (np.ndarray): A 2D numpy array with shape `(H, W)`, where `H` is the
- height, and `W` is the width.
- area_threshold (float): Threshold for relative area to remove or fill features.
- feature_type (FeatureType): Type of feature to adjust (`ISLAND` for removing
- islands, `HOLE` for filling holes).
-
- Returns:
- np.ndarray: A 2D numpy array containing mask.
- """
- height, width = mask.shape
- total_area = width * height
+ def __init__(
+ self,
+ model_name: str = "facebook/sam-vit-huge",
+ device=device,
+ input_points: List[List[int]] = [[450, 600]],
+ *args,
+ **kwargs,
+ ):
+ self.model_name = model_name
+ self.device = device
+ self.input_points = input_points
- mask = np.uint8(mask * 255)
- operation = (
- cv2.RETR_EXTERNAL
- if feature_type == FeatureType.ISLAND
- else cv2.RETR_CCOMP
- )
- contours, _ = cv2.findContours(
- mask, operation, cv2.CHAIN_APPROX_SIMPLE
- )
+ self.model = SamModel.from_pretrained(
+ model_name, *args, **kwargs
+ ).to(device)
- for contour in contours:
- area = cv2.contourArea(contour)
- relative_area = area / total_area
- if relative_area < area_threshold:
- cv2.drawContours(
- image=mask,
- contours=[contour],
- contourIdx=-1,
- color=(
- 0 if feature_type == FeatureType.ISLAND else 255
- ),
- thickness=-1,
- )
- return np.where(mask > 0, 1, 0).astype(bool)
+ self.processor = SamProcessor.from_pretrained(model_name)
+ def run(self, task=None, img=None, *args, **kwargs):
+ """
+ Runs the SAM model on the given image and returns the segmentation scores and masks.
-def masks_to_marks(masks: np.ndarray) -> sv.Detections:
- """
- Converts a set of masks to a marks (sv.Detections) object.
+ Args:
+ task: The task to perform. Not used in this method.
+ img: The input image to segment.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
- Parameters:
- masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
- number of masks, `H` is the height, and `W` is the width.
+ Returns:
+ Tuple: A tuple containing the segmentation scores and masks.
- Returns:
- sv.Detections: An object containing the masks and their bounding box
- coordinates.
- """
- return sv.Detections(
- mask=masks, xyxy=sv.mask_to_xyxy(masks=masks)
- )
+ """
+ img = self.process_img(img)
+ # Specify the points of the mask to segment
+ input_points = [
+ self.input_points
+ ] # 2D location of a window in the image
-def refine_marks(
- marks: sv.Detections,
- maximum_hole_area: float = 0.01,
- maximum_island_area: float = 0.01,
- minimum_mask_area: float = 0.02,
- maximum_mask_area: float = 1.0,
-) -> sv.Detections:
- """
- Refines a set of masks by removing small islands and holes, and filtering by mask
- area.
+ # Preprocess the image
+ inputs = self.processor(
+ img, input_points=input_points, return_tensors="pt"
+ ).to(device)
- Parameters:
- marks (sv.Detections): An object containing the masks and their bounding box
- coordinates.
- maximum_hole_area (float): The maximum relative area of holes to be filled in
- each mask.
- maximum_island_area (float): The maximum relative area of islands to be removed
- from each mask.
- minimum_mask_area (float): The minimum relative area for a mask to be retained.
- maximum_mask_area (float): The maximum relative area for a mask to be retained.
+ with torch.no_grad():
+ outputs = self.model(**inputs) # noqa: E999
- Returns:
- sv.Detections: An object containing the masks and their bounding box
- coordinates.
- """
- result_masks = []
- for mask in marks.mask:
- mask = adjust_mask_features_by_relative_area(
- mask=mask,
- area_threshold=maximum_island_area,
- feature_type=FeatureType.ISLAND,
+ masks = self.processor.image_processor.post_process_masks(
+ outputs.pred_masks.cpu(),
+ inputs["original_sizes"].cpu(),
+ inputs["reshaped_input_sizes"].cpu(),
)
- mask = adjust_mask_features_by_relative_area(
- mask=mask,
- area_threshold=maximum_hole_area,
- feature_type=FeatureType.HOLE,
- )
- if np.any(mask):
- result_masks.append(mask)
- result_masks = np.array(result_masks)
- result_masks = filter_masks_by_relative_area(
- masks=result_masks,
- minimum_area=minimum_mask_area,
- maximum_area=maximum_mask_area,
- )
- return sv.Detections(
- mask=result_masks, xyxy=sv.mask_to_xyxy(masks=result_masks)
- )
-
-
-class SegmentAnythingMarkGenerator:
- """
- A class for performing image segmentation using a specified model.
+ scores = outputs.iou_scores
- Parameters:
- device (str): The device to run the model on (e.g., 'cpu', 'cuda').
- model_name (str): The name of the model to be loaded. Defaults to
- 'facebook/sam-vit-huge'.
- """
-
- def __init__(
- self,
- device: str = "cpu",
- model_name: str = "facebook/sam-vit-huge",
- ):
- self.model = SamModel.from_pretrained(model_name).to(device)
- self.processor = SamProcessor.from_pretrained(model_name)
- self.image_processor = SamImageProcessor.from_pretrained(
- model_name
- )
- self.pipeline = pipeline(
- task="mask-generation",
- model=self.model,
- image_processor=self.image_processor,
- device=device,
- )
+ return scores, masks
- def run(self, image: np.ndarray) -> sv.Detections:
+ def process_img(self, img: str = None, *args, **kwargs):
"""
- Generate image segmentation marks.
+ Processes the input image and returns the processed image.
- Parameters:
- image (np.ndarray): The image to be marked in BGR format.
+ Args:
+ img (str): The URL or file path of the input image.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
Returns:
- sv.Detections: An object containing the segmentation masks and their
- corresponding bounding box coordinates.
+ Image: The processed image.
+
"""
- image = Image.fromarray(
- cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- )
- outputs = self.pipeline(image, points_per_batch=64)
- masks = np.array(outputs["masks"])
- return masks_to_marks(masks=masks)
+ raw_image = Image.open(
+ requests.get(img, stream=True, *args, **kwargs).raw
+ ).convert("RGB")
+
+ return raw_image
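
The rewritten `SAM` wrapper takes an image URL plus a point prompt and returns the raw scores and masks. A short sketch, assuming network access and enough memory for the facebook/sam-vit-huge checkpoint; the image URL below is only a placeholder:

```python
from swarms.models.sam import SAM

# input_points is the (x, y) pixel location used as the segmentation prompt.
sam = SAM(input_points=[[450, 600]])

# process_img fetches the image over HTTP, so a reachable URL is expected.
scores, masks = sam.run(img="https://example.com/street_scene.png")

print(scores.shape)  # IoU scores per predicted mask
print(len(masks))    # post-processed masks, one entry per input image
```
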
diff --git a/swarms/models/types.py b/swarms/models/types.py
new file mode 100644
index 00000000..460d0ef7
--- /dev/null
+++ b/swarms/models/types.py
@@ -0,0 +1,28 @@
+from pydantic import BaseModel
+from typing import List, Optional
+
+
+class TextModality(BaseModel):
+ content: str
+
+
+class ImageModality(BaseModel):
+ url: str
+ alt_text: Optional[str]
+
+
+class AudioModality(BaseModel):
+ url: str
+ transcript: Optional[str]
+
+
+class VideoModality(BaseModel):
+ url: str
+ transcript: Optional[str]
+
+
+class MultimodalData(BaseModel):
+ text: Optional[List[TextModality]]
+ images: Optional[List[ImageModality]]
+ audio: Optional[List[AudioModality]]
+ video: Optional[List[VideoModality]]
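
The modality schemas are plain pydantic models, so they can be composed directly; a small sketch with illustrative field values:

```python
from swarms.models.types import ImageModality, MultimodalData, TextModality

payload = MultimodalData(
    text=[TextModality(content="A caption describing the scene")],
    images=[
        ImageModality(url="https://example.com/scene.png", alt_text="the scene")
    ],
    audio=None,
    video=None,
)

print(payload)  # validated container holding every modality in one object
```
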
diff --git a/swarms/models/zeroscope.py b/swarms/models/zeroscope.py
new file mode 100644
index 00000000..43a11206
--- /dev/null
+++ b/swarms/models/zeroscope.py
@@ -0,0 +1,103 @@
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_video
+
+
+class ZeroscopeTTV:
+ """
+ ZeroscopeTTV class represents a zero-shot video generation model.
+
+ Args:
+ model_name (str): The name of the pre-trained model to use.
+ torch_dtype (torch.dtype): The torch data type to use for computations.
+ chunk_size (int): The size of chunks for forward chunking.
+ dim (int): The dimension along which to split the input for forward chunking.
+ num_inference_steps (int): The number of inference steps to perform.
+ height (int): The height of the video frames.
+ width (int): The width of the video frames.
+ num_frames (int): The number of frames in the video.
+
+ Attributes:
+ model_name (str): The name of the pre-trained model.
+ torch_dtype (torch.dtype): The torch data type used for computations.
+ chunk_size (int): The size of chunks for forward chunking.
+ dim (int): The dimension along which the input is split for forward chunking.
+ num_inference_steps (int): The number of inference steps to perform.
+ height (int): The height of the video frames.
+ width (int): The width of the video frames.
+ num_frames (int): The number of frames in the video.
+ pipe (DiffusionPipeline): The diffusion pipeline for video generation.
+
+ Methods:
+        run(task: str = None, *args, **kwargs) -> str:
+            Generates the video for the input task and returns the path of the exported file.
+
+ Examples:
+        >>> from swarms.models import ZeroscopeTTV
+ >>> zeroscope = ZeroscopeTTV()
+ >>> task = "A person is walking on the street."
+ >>> video_path = zeroscope(task)
+
+ """
+
+ def __init__(
+ self,
+ model_name: str = "cerspense/zeroscope_v2_576w",
+ torch_dtype=torch.float16,
+ chunk_size: int = 1,
+ dim: int = 1,
+ num_inference_steps: int = 40,
+ height: int = 320,
+ width: int = 576,
+ num_frames: int = 36,
+ *args,
+ **kwargs,
+ ):
+ self.model_name = model_name
+ self.torch_dtype = torch_dtype
+ self.chunk_size = chunk_size
+ self.dim = dim
+ self.num_inference_steps = num_inference_steps
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+
+ self.pipe = DiffusionPipeline.from_pretrained(
+ model_name,
+ torch_dtype=torch_dtype,
+ *args,
+ )
+        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+            self.pipe.scheduler.config,
+        )
+        self.pipe.enable_model_cpu_offload()
+ self.pipe.enable_vae_slicing()
+ self.pipe.unet.enable_forward_chunking(
+ chunk_size=chunk_size, dim=dim
+ )
+
+ def run(self, task: str = None, *args, **kwargs):
+ """
+ Performs a forward pass on the input task and returns the path of the generated video.
+
+ Args:
+ task (str): The input task for video generation.
+
+ Returns:
+ str: The path of the generated video.
+ """
+ try:
+ video_frames = self.pipe(
+ task,
+ num_inference_steps=self.num_inference_steps,
+ height=self.height,
+ width=self.width,
+ num_frames=self.num_frames,
+ *args,
+ **kwargs,
+ ).frames
+ video_path = export_to_video(video_frames)
+ return video_path
+ except Exception as error:
+ print(f"Error in [ZeroscopeTTV.forward]: {error}")
+ raise error
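
A usage sketch for `ZeroscopeTTV`. Generation needs a CUDA GPU and downloads the cerspense/zeroscope_v2_576w weights on first use; the reduced frame count here only keeps the demo cheap:

```python
from swarms.models import ZeroscopeTTV

zeroscope = ZeroscopeTTV(num_frames=24, num_inference_steps=30)

# run() renders the clip and returns the path of the exported video file.
video_path = zeroscope.run("A person is walking on the street.")
print(f"Video written to {video_path}")
```
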
diff --git a/swarms/prompts/__init__.py b/swarms/prompts/__init__.py
index 6417dc85..dbdc7c7b 100644
--- a/swarms/prompts/__init__.py
+++ b/swarms/prompts/__init__.py
@@ -6,7 +6,7 @@ from swarms.prompts.operations_agent_prompt import (
OPERATIONS_AGENT_PROMPT,
)
from swarms.prompts.product_agent_prompt import PRODUCT_AGENT_PROMPT
-
+from swarms.prompts.documentation import DOCUMENTATION_WRITER_SOP
__all__ = [
"CODE_INTERPRETER",
@@ -15,4 +15,5 @@ __all__ = [
"LEGAL_AGENT_PROMPT",
"OPERATIONS_AGENT_PROMPT",
"PRODUCT_AGENT_PROMPT",
+ "DOCUMENTATION_WRITER_SOP",
]
diff --git a/swarms/prompts/documentation.py b/swarms/prompts/documentation.py
index 3ed2eb8c..3c0a588e 100644
--- a/swarms/prompts/documentation.py
+++ b/swarms/prompts/documentation.py
@@ -1,5 +1,8 @@
-def documentation(task: str):
- documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the code below follow the outline for the library,
+def DOCUMENTATION_WRITER_SOP(
+ task: str,
+ module: str,
+):
+ documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the {module} code below follow the outline for the {module} library,
provide many examples and teach the user about the code, provide examples for every function, make the documentation 10,000 words,
provide many usage examples and note this is markdown docs, create the documentation for the code to document,
put the arguments and methods in a table in markdown to make it visually seamless
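
`DOCUMENTATION_WRITER_SOP` is now a function, so the module name is interpolated into the prompt. A quick sketch; the source string and module name are placeholders:

```python
from swarms.prompts import DOCUMENTATION_WRITER_SOP

source = "class Conversation:\n    ..."  # code you want documented

prompt = DOCUMENTATION_WRITER_SOP(task=source, module="swarms.structs")
print(prompt[:200])  # the module name now appears inside the instructions
```
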
diff --git a/swarms/prompts/tests.py b/swarms/prompts/tests.py
index df1ee92d..8dac9337 100644
--- a/swarms/prompts/tests.py
+++ b/swarms/prompts/tests.py
@@ -1,89 +1,95 @@
-TESTS_PROMPT = """
+def TEST_WRITER_SOP_PROMPT(
+ task: str, module: str, path: str, *args, **kwargs
+):
+ TESTS_PROMPT = f"""
-Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any
-just write the best tests possible:
+ Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any
+ just write the best tests possible, the module is {module}, the file path is {path}
-######### TESTING GUIDE #############
+ ######### TESTING GUIDE #############
-# **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`**
+ # **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`**
-1. **Preparation**:
- - Install pytest: `pip install pytest`.
- - Structure your project so that tests are in a separate `tests/` directory.
- - Name your test files with the prefix `test_` for pytest to recognize them.
+ 1. **Preparation**:
+ - Install pytest: `pip install pytest`.
+ - Structure your project so that tests are in a separate `tests/` directory.
+ - Name your test files with the prefix `test_` for pytest to recognize them.
-2. **Writing Basic Tests**:
- - Use clear function names prefixed with `test_` (e.g., `test_check_value()`).
- - Use assert statements to validate results.
+ 2. **Writing Basic Tests**:
+ - Use clear function names prefixed with `test_` (e.g., `test_check_value()`).
+ - Use assert statements to validate results.
-3. **Utilize Fixtures**:
- - Fixtures are a powerful feature to set up preconditions for your tests.
- - Use `@pytest.fixture` decorator to define a fixture.
- - Pass fixture name as an argument to your test to use it.
+ 3. **Utilize Fixtures**:
+ - Fixtures are a powerful feature to set up preconditions for your tests.
+ - Use `@pytest.fixture` decorator to define a fixture.
+ - Pass fixture name as an argument to your test to use it.
-4. **Parameterized Testing**:
- - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs.
- - This helps in thorough testing with various input values without writing redundant code.
+ 4. **Parameterized Testing**:
+ - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs.
+ - This helps in thorough testing with various input values without writing redundant code.
-5. **Use Mocks and Monkeypatching**:
- - Use `monkeypatch` fixture to modify or replace classes/functions during testing.
- - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code.
+ 5. **Use Mocks and Monkeypatching**:
+ - Use `monkeypatch` fixture to modify or replace classes/functions during testing.
+ - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code.
-6. **Exception Testing**:
- - Test for expected exceptions using `pytest.raises(ExceptionType)`.
+ 6. **Exception Testing**:
+ - Test for expected exceptions using `pytest.raises(ExceptionType)`.
-7. **Test Coverage**:
- - Install pytest-cov: `pip install pytest-cov`.
- - Run tests with `pytest --cov=my_module` to get a coverage report.
+ 7. **Test Coverage**:
+ - Install pytest-cov: `pip install pytest-cov`.
+ - Run tests with `pytest --cov=my_module` to get a coverage report.
-8. **Environment Variables and Secret Handling**:
- - Store secrets and configurations in environment variables.
- - Use libraries like `python-decouple` or `python-dotenv` to load environment variables.
- - For tests, mock or set environment variables temporarily within the test environment.
+ 8. **Environment Variables and Secret Handling**:
+ - Store secrets and configurations in environment variables.
+ - Use libraries like `python-decouple` or `python-dotenv` to load environment variables.
+ - For tests, mock or set environment variables temporarily within the test environment.
-9. **Grouping and Marking Tests**:
- - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`).
- - This allows for selectively running certain groups of tests.
+ 9. **Grouping and Marking Tests**:
+ - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`).
+ - This allows for selectively running certain groups of tests.
-10. **Use Plugins**:
- - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs.
+ 10. **Use Plugins**:
+ - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs.
-11. **Continuous Integration (CI)**:
- - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions.
- - Ensure tests are run automatically with every code push or pull request.
+ 11. **Continuous Integration (CI)**:
+ - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions.
+ - Ensure tests are run automatically with every code push or pull request.
-12. **Logging and Reporting**:
- - Use `pytest`'s inbuilt logging.
- - Integrate with tools like `Allure` for more comprehensive reporting.
+ 12. **Logging and Reporting**:
+ - Use `pytest`'s inbuilt logging.
+ - Integrate with tools like `Allure` for more comprehensive reporting.
-13. **Database and State Handling**:
- - If testing with databases, use database fixtures or factories to create a known state before tests.
- - Clean up and reset state post-tests to maintain consistency.
+ 13. **Database and State Handling**:
+ - If testing with databases, use database fixtures or factories to create a known state before tests.
+ - Clean up and reset state post-tests to maintain consistency.
-14. **Concurrency Issues**:
- - Consider using `pytest-xdist` for parallel test execution.
- - Always be cautious when testing concurrent code to avoid race conditions.
+ 14. **Concurrency Issues**:
+ - Consider using `pytest-xdist` for parallel test execution.
+ - Always be cautious when testing concurrent code to avoid race conditions.
-15. **Clean Code Practices**:
- - Ensure tests are readable and maintainable.
- - Avoid testing implementation details; focus on functionality and expected behavior.
+ 15. **Clean Code Practices**:
+ - Ensure tests are readable and maintainable.
+ - Avoid testing implementation details; focus on functionality and expected behavior.
-16. **Regular Maintenance**:
- - Periodically review and update tests.
- - Ensure that tests stay relevant as your codebase grows and changes.
+ 16. **Regular Maintenance**:
+ - Periodically review and update tests.
+ - Ensure that tests stay relevant as your codebase grows and changes.
-17. **Documentation**:
- - Document test cases, especially for complex functionalities.
- - Ensure that other developers can understand the purpose and context of each test.
+ 17. **Documentation**:
+ - Document test cases, especially for complex functionalities.
+ - Ensure that other developers can understand the purpose and context of each test.
-18. **Feedback Loop**:
- - Use test failures as feedback for development.
- - Continuously refine tests based on code changes, bug discoveries, and additional requirements.
+ 18. **Feedback Loop**:
+ - Use test failures as feedback for development.
+ - Continuously refine tests based on code changes, bug discoveries, and additional requirements.
-By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project.
+ By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project.
-######### CREATE TESTS FOR THIS CODE: #######
+ ######### CREATE TESTS FOR THIS CODE: #######
+ {task}
-"""
+ """
+
+ return TESTS_PROMPT
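
`TEST_WRITER_SOP_PROMPT` follows the same pattern, embedding the code, module, and target test path into the prompt. A sketch with placeholder values:

```python
from swarms.prompts.tests import TEST_WRITER_SOP_PROMPT

code_under_test = "def add(a: int, b: int) -> int:\n    return a + b\n"

prompt = TEST_WRITER_SOP_PROMPT(
    task=code_under_test,
    module="swarms.utils.math",        # placeholder module name
    path="tests/utils/test_math.py",   # placeholder test file path
)
print(prompt[:200])
```
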
diff --git a/swarms/swarms/README.md b/swarms/structs/SWARMS.md
similarity index 100%
rename from swarms/swarms/README.md
rename to swarms/structs/SWARMS.md
diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py
index 4a58ea8d..f0388493 100644
--- a/swarms/structs/__init__.py
+++ b/swarms/structs/__init__.py
@@ -1,13 +1,27 @@
from swarms.structs.agent import Agent
-from swarms.structs.sequential_workflow import SequentialWorkflow
from swarms.structs.autoscaler import AutoScaler
+from swarms.structs.base_swarm import AbstractSwarm
from swarms.structs.conversation import Conversation
+from swarms.structs.groupchat import GroupChat, GroupChatManager
+from swarms.structs.model_parallizer import ModelParallelizer
+from swarms.structs.multi_agent_collab import MultiAgentCollaboration
from swarms.structs.schemas import (
- TaskInput,
Artifact,
ArtifactUpload,
StepInput,
+ TaskInput,
+)
+from swarms.structs.sequential_workflow import SequentialWorkflow
+from swarms.structs.swarm_net import SwarmNetwork
+from swarms.structs.utils import (
+ distribute_tasks,
+ extract_key_from_json,
+ extract_tokens_from_text,
+ find_agent_by_id,
+ find_token_in_text,
+ parse_tasks,
)
+from swarms.structs.concurrent_workflow import ConcurrentWorkflow
__all__ = [
"Agent",
@@ -18,4 +32,17 @@ __all__ = [
"Artifact",
"ArtifactUpload",
"StepInput",
+ "SwarmNetwork",
+ "ModelParallelizer",
+ "MultiAgentCollaboration",
+ "AbstractSwarm",
+ "GroupChat",
+ "GroupChatManager",
+ "parse_tasks",
+ "find_agent_by_id",
+ "distribute_tasks",
+ "find_token_in_text",
+ "extract_key_from_json",
+ "extract_tokens_from_text",
+ "ConcurrentWorkflow",
]
diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index be5c7121..3903d4ad 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -25,7 +25,7 @@ from swarms.tools.tool import BaseTool
from swarms.tools.tool_func_doc_scraper import scrape_tool_func_docs
from swarms.utils.code_interpreter import SubprocessCodeInterpreter
from swarms.utils.parse_code import (
- extract_code_in_backticks_in_string,
+ extract_code_from_markdown,
)
from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.token_count_tiktoken import limit_tokens_from_string
@@ -63,98 +63,95 @@ class Agent:
* Ability to provide a loop interval
Args:
- id (str): The id of the agent
llm (Any): The language model to use
- template (Optional[str]): The template to use
- max_loops (int): The maximum number of loops
- stopping_condition (Optional[Callable[[str], bool]]): The stopping condition
+ template (str): The template to use
+ max_loops (int): The maximum number of loops to run
+ stopping_condition (Callable): The stopping condition to use
loop_interval (int): The loop interval
- retry_attempts (int): The retry attempts
+ retry_attempts (int): The number of retry attempts
retry_interval (int): The retry interval
return_history (bool): Return the history
stopping_token (str): The stopping token
- dynamic_loops (Optional[bool]): Dynamic loops
- interactive (bool): Interactive mode
- dashboard (bool): Dashboard mode
+ dynamic_loops (bool): Enable dynamic loops
+ interactive (bool): Enable interactive mode
+ dashboard (bool): Enable dashboard
agent_name (str): The name of the agent
agent_description (str): The description of the agent
system_prompt (str): The system prompt
- tools (List[BaseTool]): The tools
- dynamic_temperature_enabled (Optional[bool]): Dynamic temperature enabled
- sop (Optional[str]): The standard operating procedure
- sop_list (Optional[List[str]]): The standard operating procedure list
- saved_state_path (Optional[str]): The saved state path
- autosave (Optional[bool]): Autosave
- context_length (Optional[int]): The context length
+ tools (List[BaseTool]): The tools to use
+ dynamic_temperature_enabled (bool): Enable dynamic temperature
+ sop (str): The standard operating procedure
+ sop_list (List[str]): The standard operating procedure list
+ saved_state_path (str): The path to the saved state
+ autosave (bool): Autosave the state
+ context_length (int): The context length
user_name (str): The user name
- self_healing_enabled (Optional[bool]): Self healing enabled
- code_interpreter (Optional[bool]): Code interpreter
- multi_modal (Optional[bool]): Multi modal
- pdf_path (Optional[str]): The pdf path
- list_of_pdf (Optional[str]): The list of pdf
- tokenizer (Optional[Any]): The tokenizer
- memory (Optional[VectorDatabase]): The memory
- preset_stopping_token (Optional[bool]): Preset stopping token
- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.
+ self_healing_enabled (bool): Enable self healing
+ code_interpreter (bool): Enable code interpreter
+ multi_modal (bool): Enable multimodal
+ pdf_path (str): The path to the pdf
+ list_of_pdf (str): The list of pdf
+ tokenizer (Any): The tokenizer
+ memory (VectorDatabase): The memory
+ preset_stopping_token (bool): Enable preset stopping token
+ traceback (Any): The traceback
+ traceback_handlers (Any): The traceback handlers
+ streaming_on (bool): Enable streaming
Methods:
- run(task: str, **kwargs: Any): Run the agent on a task
- run_concurrent(tasks: List[str], **kwargs: Any): Run the agent on a list of tasks concurrently
- bulk_run(inputs: List[Dict[str, Any]]): Run the agent on a list of inputs
- from_llm_and_template(llm: Any, template: str): Create AgentStream from LLM and a string template.
- from_llm_and_template_file(llm: Any, template_file: str): Create AgentStream from LLM and a template file.
- save(file_path): Save the agent history to a file
- load(file_path): Load the agent history from a file
- validate_response(response: str): Validate the response based on certain criteria
- print_history_and_memory(): Print the entire history and memory of the agent
- step(task: str, **kwargs): Executes a single step in the agent interaction, generating a response from the language model based on the given input text.
- graceful_shutdown(): Gracefully shutdown the system saving the state
- run_with_timeout(task: str, timeout: int): Run the loop but stop if it takes longer than the timeout
- analyze_feedback(): Analyze the feedback for issues
- undo_last(): Response the last response and return the previous state
- add_response_filter(filter_word: str): Add a response filter to filter out certain words from the response
- apply_reponse_filters(response: str): Apply the response filters to the response
- filtered_run(task: str): Filtered run
- interactive_run(max_loops: int): Interactive run mode
- streamed_generation(prompt: str): Stream the generation of the response
- get_llm_params(): Extracts and returns the parameters of the llm object for serialization.
- save_state(file_path: str): Saves the current state of the agent to a JSON file, including the llm parameters.
- load_state(file_path: str): Loads the state of the agent from a json file and restores the configuration and memory.
- retry_on_failure(function, retries: int = 3, retry_delay: int = 1): Retry wrapper for LLM calls.
- run_code(response: str): Run the code in the response
- construct_dynamic_prompt(): Construct the dynamic prompt
- extract_tool_commands(text: str): Extract the tool commands from the text
- parse_and_execute_tools(response: str): Parse and execute the tools
- execute_tools(tool_name, params): Execute the tool with the provided params
- truncate_history(): Take the history and truncate it to fit into the model context length
- add_task_to_memory(task: str): Add the task to the memory
- add_message_to_memory(message: str): Add the message to the memory
- add_message_to_memory_and_truncate(message: str): Add the message to the memory and truncate
- print_dashboard(task: str): Print dashboard
- activate_autonomous_agent(): Print the autonomous agent activation message
- dynamic_temperature(): Dynamically change the temperature
- _check_stopping_condition(response: str): Check if the stopping condition is met
- format_prompt(template, **kwargs: Any): Format the template with the provided kwargs using f-string interpolation.
- get_llm_init_params(): Get LLM init params
- get_tool_description(): Get the tool description
- find_tool_by_name(name: str): Find a tool by name
-
-
- Example:
+ run: Run the agent
+ run_concurrent: Run the agent concurrently
+ bulk_run: Run the agent in bulk
+ save: Save the agent
+ load: Load the agent
+ validate_response: Validate the response
+ print_history_and_memory: Print the history and memory
+ step: Step through the agent
+ graceful_shutdown: Gracefully shutdown the agent
+ run_with_timeout: Run the agent with a timeout
+ analyze_feedback: Analyze the feedback
+ undo_last: Undo the last response
+ add_response_filter: Add a response filter
+ apply_response_filters: Apply the response filters
+ filtered_run: Run the agent with filtered responses
+ interactive_run: Run the agent in interactive mode
+ streamed_generation: Stream the generation of the response
+ get_llm_params: Get the llm parameters
+ save_state: Save the state
+ load_state: Load the state
+ get_llm_init_params: Get the llm init parameters
+ get_tool_description: Get the tool description
+ find_tool_by_name: Find a tool by name
+ extract_tool_commands: Extract the tool commands
+ execute_tools: Execute the tools
+ parse_and_execute_tools: Parse and execute the tools
+ truncate_history: Truncate the history
+ add_task_to_memory: Add the task to the memory
+ add_message_to_memory: Add the message to the memory
+ add_message_to_memory_and_truncate: Add the message to the memory and truncate
+ parse_tool_docs: Parse the tool docs
+ print_dashboard: Print the dashboard
+ loop_count_print: Print the loop count
+ streaming: Stream the content
+ _history: Generate the history
+ _dynamic_prompt_setup: Setup the dynamic prompt
+ agent_system_prompt_2: Agent system prompt 2
+ run_async: Run the agent asynchronously
+        run_async_concurrent: Run the agent asynchronously and concurrently
+        construct_dynamic_prompt: Construct the dynamic prompt
+
+
+ Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import Agent
- >>> llm = OpenAIChat(
- ... openai_api_key=api_key,
- ... temperature=0.5,
- ... )
- >>> agent = Agent(
- ... llm=llm, max_loops=5,
- ... #system_prompt=SYSTEM_PROMPT,
- ... #retry_interval=1,
- ... )
- >>> agent.run("Generate a 10,000 word blog")
- >>> agent.save("path/agent.yaml")
+ >>> llm = OpenAIChat()
+ >>> agent = Agent(llm=llm, max_loops=1)
+ >>> response = agent.run("Generate a report on the financials.")
+ >>> print(response)
+ >>> # Generate a report on the financials.
+
"""
def __init__(
@@ -172,7 +169,7 @@ class Agent:
dynamic_loops: Optional[bool] = False,
interactive: bool = False,
dashboard: bool = False,
- agent_name: str = "Autonomous-Agent-XYZ1B",
+ agent_name: str = None,
agent_description: str = None,
system_prompt: str = AGENT_SYSTEM_PROMPT_3,
tools: List[BaseTool] = None,
@@ -1257,7 +1254,7 @@ class Agent:
"""
text -> parse_code by looking for code inside 6 backticks `````-> run_code
"""
- parsed_code = extract_code_in_backticks_in_string(code)
+ parsed_code = extract_code_from_markdown(code)
run_code = self.code_executor.run(parsed_code)
return run_code
diff --git a/swarms/structs/all_to_all_swarm.py b/swarms/structs/all_to_all_swarm.py
new file mode 100644
index 00000000..e69de29b
diff --git a/swarms/swarms/base.py b/swarms/structs/base_swarm.py
similarity index 100%
rename from swarms/swarms/base.py
rename to swarms/structs/base_swarm.py
diff --git a/swarms/structs/concurrent_workflow.py b/swarms/structs/concurrent_workflow.py
new file mode 100644
index 00000000..05a595e6
--- /dev/null
+++ b/swarms/structs/concurrent_workflow.py
@@ -0,0 +1,96 @@
+import concurrent.futures
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+from swarms.structs.base import BaseStruct
+from swarms.structs.task import Task
+
+
+@dataclass
+class ConcurrentWorkflow(BaseStruct):
+ """
+ ConcurrentWorkflow class for running a set of tasks concurrently using N number of autonomous agents.
+
+ Args:
+ max_workers (int): The maximum number of workers to use for concurrent execution.
+ autosave (bool): Whether to autosave the workflow state.
+ saved_state_filepath (Optional[str]): The file path to save the workflow state.
+
+ Attributes:
+ tasks (List[Task]): The list of tasks to execute.
+ max_workers (int): The maximum number of workers to use for concurrent execution.
+ autosave (bool): Whether to autosave the workflow state.
+ saved_state_filepath (Optional[str]): The file path to save the workflow state.
+
+ Examples:
+ >>> from swarms.models import OpenAIChat
+ >>> from swarms.structs import ConcurrentWorkflow
+ >>> llm = OpenAIChat(openai_api_key="")
+ >>> workflow = ConcurrentWorkflow(max_workers=5)
+ >>> workflow.add("What's the weather in miami", llm)
+ >>> workflow.add("Create a report on these metrics", llm)
+ >>> workflow.run()
+ >>> workflow.tasks
+ """
+
+    tasks: List[Task] = field(default_factory=list)
+ max_workers: int = 5
+ autosave: bool = False
+ saved_state_filepath: Optional[str] = (
+ "runs/concurrent_workflow.json"
+ )
+ print_results: bool = False
+ return_results: bool = False
+ use_processes: bool = False
+
+ def add(self, task: Task):
+ """Adds a task to the workflow.
+
+ Args:
+ task (Task): _description_
+ """
+ self.tasks.append(task)
+
+ def run(self):
+ """
+ Executes the tasks in parallel using a ThreadPoolExecutor.
+
+        Results are printed and/or collected according to the instance
+        attributes `print_results` and `return_results`.
+
+ Returns:
+ List[Any]: A list of the results of each task, if return_results is True. Otherwise, returns None.
+ """
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=self.max_workers
+ ) as executor:
+ futures = {
+ executor.submit(task.execute): task
+ for task in self.tasks
+ }
+ results = []
+
+ for future in concurrent.futures.as_completed(futures):
+ task = futures[future]
+ try:
+ result = future.result()
+ if self.print_results:
+ print(f"Task {task}: {result}")
+ if self.return_results:
+ results.append(result)
+ except Exception as e:
+ print(f"Task {task} generated an exception: {e}")
+
+ return results if self.return_results else None
+
+ def _execute_task(self, task: Task):
+ """Executes a task.
+
+ Args:
+ task (Task): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ return task.run()
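
`ConcurrentWorkflow.run` submits `task.execute` to the thread pool, so any object exposing an `execute` method can be scheduled. A self-contained sketch with a stand-in task; the real `Task` struct from `swarms.structs.task` is the intended input, and the example assumes `BaseStruct` requires no constructor arguments:

```python
from swarms.structs import ConcurrentWorkflow


class EchoTask:
    """Stand-in for swarms.structs.task.Task; only `execute` is needed here."""

    def __init__(self, prompt: str):
        self.prompt = prompt

    def execute(self):
        return f"done: {self.prompt}"


workflow = ConcurrentWorkflow(max_workers=2, return_results=True)
workflow.add(EchoTask("What's the weather in Miami?"))
workflow.add(EchoTask("Create a report on these metrics"))

# Results are collected as each future completes, so ordering is not guaranteed.
print(workflow.run())
```
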
diff --git a/swarms/swarms/groupchat.py b/swarms/structs/groupchat.py
similarity index 98%
rename from swarms/swarms/groupchat.py
rename to swarms/structs/groupchat.py
index f3677023..21fff944 100644
--- a/swarms/swarms/groupchat.py
+++ b/swarms/structs/groupchat.py
@@ -140,7 +140,6 @@ class GroupChatManager:
>>> from swarms import GroupChatManager
>>> from swarms.structs.agent import Agent
>>> agents = Agent()
- >>> output = GroupChatManager(agents, lambda x: x)
"""
diff --git a/swarms/swarms/god_mode.py b/swarms/structs/model_parallizer.py
similarity index 71%
rename from swarms/swarms/god_mode.py
rename to swarms/structs/model_parallizer.py
index 29178b2c..3844f5b4 100644
--- a/swarms/swarms/god_mode.py
+++ b/swarms/structs/model_parallizer.py
@@ -11,17 +11,17 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-class GodMode:
+class ModelParallelizer:
"""
- GodMode
+ ModelParallelizer
-----
Architecture:
How it works:
- 1. GodMode receives a task from the user.
- 2. GodMode distributes the task to all LLMs.
- 3. GodMode collects the responses from all LLMs.
- 4. GodMode prints the responses from all LLMs.
+ 1. ModelParallelizer receives a task from the user.
+ 2. ModelParallelizer distributes the task to all LLMs.
+ 3. ModelParallelizer collects the responses from all LLMs.
+ 4. ModelParallelizer prints the responses from all LLMs.
Parameters:
llms: list of LLMs
@@ -31,30 +31,42 @@ class GodMode:
print_responses(task): print responses from all LLMs
Usage:
- god_mode = GodMode(llms)
- god_mode.run(task)
- god_mode.print_responses(task)
+ parallelizer = ModelParallelizer(llms)
+ parallelizer.run(task)
+ parallelizer.print_responses(task)
"""
def __init__(
self,
- llms: List[Callable],
+ llms: List[Callable] = None,
load_balancing: bool = False,
retry_attempts: int = 3,
+ iters: int = None,
+ *args,
+ **kwargs,
):
self.llms = llms
self.load_balancing = load_balancing
self.retry_attempts = retry_attempts
+ self.iters = iters
self.last_responses = None
self.task_history = []
def run(self, task: str):
"""Run the task string"""
- with ThreadPoolExecutor() as executor:
- responses = executor.map(lambda llm: llm(task), self.llms)
- return list(responses)
+ try:
+            for _ in range(self.iters or 1):
+ with ThreadPoolExecutor() as executor:
+ responses = executor.map(
+ lambda llm: llm(task), self.llms
+ )
+ return list(responses)
+ except Exception as error:
+ print(
+ f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
+ )
def print_responses(self, task):
"""Prints the responses in a tabular format"""
@@ -161,22 +173,29 @@ class GodMode:
def concurrent_run(self, task: str) -> List[str]:
"""Synchronously run the task on all llms and collect responses"""
- with ThreadPoolExecutor() as executor:
- future_to_llm = {
- executor.submit(llm, task): llm for llm in self.llms
- }
- responses = []
- for future in as_completed(future_to_llm):
- try:
- responses.append(future.result())
- except Exception as error:
- print(
- f"{future_to_llm[future]} generated an"
- f" exception: {error}"
- )
- self.last_responses = responses
- self.task_history.append(task)
- return responses
+ try:
+ with ThreadPoolExecutor() as executor:
+ future_to_llm = {
+ executor.submit(llm, task): llm
+ for llm in self.llms
+ }
+ responses = []
+ for future in as_completed(future_to_llm):
+ try:
+ responses.append(future.result())
+ except Exception as error:
+ print(
+ f"{future_to_llm[future]} generated an"
+ f" exception: {error}"
+ )
+ self.last_responses = responses
+ self.task_history.append(task)
+ return responses
+ except Exception as error:
+ print(
+ f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
+ )
+ raise error
def add_llm(self, llm: Callable):
"""Add an llm to the god mode"""
diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/structs/multi_agent_collab.py
similarity index 100%
rename from swarms/swarms/multi_agent_collab.py
rename to swarms/structs/multi_agent_collab.py
diff --git a/swarms/structs/swarm_net.py b/swarms/structs/swarm_net.py
new file mode 100644
index 00000000..4a0ae0de
--- /dev/null
+++ b/swarms/structs/swarm_net.py
@@ -0,0 +1,320 @@
+import asyncio
+import logging
+import queue
+import threading
+from typing import List, Optional
+
+from fastapi import FastAPI
+
+from swarms.structs.agent import Agent
+from swarms.structs.base import BaseStructure
+
+
+class SwarmNetwork(BaseStructure):
+ """
+ SwarmNetwork class
+
+ The SwarmNetwork class is responsible for managing the agents pool
+ and the task queue. It also monitors the health of the agents and
+ scales the pool up or down based on the number of pending tasks
+ and the current load of the agents.
+
+ For example, if the number of pending tasks is greater than the
+ number of agents in the pool, the SwarmNetwork will scale up the
+ pool by adding new agents. If the number of pending tasks is less
+ than the number of agents in the pool, the SwarmNetwork will scale
+ down the pool by removing agents.
+
+    The SwarmNetwork class also provides a simple API for interacting
+    with the agent pool. The API is implemented with FastAPI and is
+    disabled by default; it can be enabled by setting the `api_enabled`
+    parameter to True.
+
+ Features:
+ - Agent pool management
+ - Task queue management
+ - Agent health monitoring
+ - Agent pool scaling
+ - Simple API for interacting with the agent pool
+ - Simple API for interacting with the task queue
+ - Simple API for interacting with the agent health monitor
+ - Simple API for interacting with the agent pool scaler
+ - Create APIs for each agent in the pool (optional)
+        - Run each agent on its own thread
+        - Run each agent on its own process
+        - Run each agent on its own container
+        - Run each agent on its own machine
+        - Run each agent on its own cluster
+
+
+ """
+
+ def __init__(
+ self,
+ idle_threshold: float = 0.2,
+ busy_threshold: float = 0.7,
+ agents: List[Agent] = None,
+ api_enabled: Optional[bool] = False,
+ logging_enabled: Optional[bool] = False,
+ *args,
+ **kwargs,
+ ):
+ self.task_queue = queue.Queue()
+ self.idle_threshold = idle_threshold
+ self.busy_threshold = busy_threshold
+ self.lock = threading.Lock()
+ self.agents = agents
+ self.api_enabled = api_enabled
+ self.logging_enabled = logging_enabled
+
+ logging.basicConfig(level=logging.INFO)
+ self.logger = logging.getLogger(__name__)
+
+ if api_enabled:
+ self.api = FastAPI()
+
+        self.agents_pool = []
+
+ def add_task(self, task):
+ """Add task to the task queue
+
+ Args:
+            task (Any): The task to add to the queue.
+
+ Example:
+ >>> from swarms.structs.agent import Agent
+ >>> from swarms.structs.swarm_net import SwarmNetwork
+ >>> agent = Agent()
+ >>> swarm = SwarmNetwork(agents=[agent])
+ >>> swarm.add_task("task")
+
+ """
+ self.logger.info(f"Adding task {task} to queue")
+ try:
+ self.task_queue.put(task)
+ self.logger.info(f"Task {task} added to queue")
+ except Exception as error:
+ print(
+ f"Error adding task to queue: {error} try again with"
+ " a new task"
+ )
+ raise error
+
+ async def async_add_task(self, task):
+ """Add task to the task queue
+
+ Args:
+            task (Any): The task to add to the queue.
+
+ Example:
+ >>> from swarms.structs.agent import Agent
+ >>> from swarms.structs.swarm_net import SwarmNetwork
+ >>> agent = Agent()
+ >>> swarm = SwarmNetwork(agents=[agent])
+ >>> swarm.add_task("task")
+
+ """
+ self.logger.info(
+ f"Adding task {task} to queue asynchronously"
+ )
+ try:
+ # Add task to queue asynchronously with asyncio
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(
+ None, self.task_queue.put, task
+ )
+ self.logger.info(f"Task {task} added to queue")
+ except Exception as error:
+ print(
+ f"Error adding task to queue: {error} try again with"
+ " a new task"
+ )
+ raise error
+
+ def run_single_agent(
+ self, agent_id, task: Optional[str] = None, *args, **kwargs
+ ):
+ """Run agent the task on the agent id
+
+ Args:
+ agent_id (_type_): _description_
+ task (str, optional): _description_. Defaults to None.
+
+ Raises:
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+ self.logger.info(f"Running task {task} on agent {agent_id}")
+ try:
+ for agent in self.agents_pool:
+ if agent.id == agent_id:
+ return agent.run(task, *args, **kwargs)
+ self.logger.info(f"No agent found with ID {agent_id}")
+ raise ValueError(f"No agent found with ID {agent_id}")
+ except Exception as error:
+ print(f"Error running task on agent: {error}")
+ raise error
+
+ def run_many_agents(
+ self, task: Optional[str] = None, *args, **kwargs
+ ) -> List:
+ """Run the task on all agents
+
+ Args:
+            task (str, optional): The task to run. Defaults to None.
+
+        Returns:
+            List: The results from each agent.
+ """
+ self.logger.info(f"Running task {task} on all agents")
+ try:
+ return [
+ agent.run(task, *args, **kwargs)
+ for agent in self.agents_pool
+ ]
+ except Exception as error:
+ print(f"Error running task on agents: {error}")
+ raise error
+
+ def list_agents(self):
+ """List all agents
+
+        Returns:
+            None: The agents are logged rather than returned.
+        """
+        self.logger.info("[Listing all active agents]")
+        try:
+            # return [agent.id for agent in self.agents_pool]
+            num_agents = len(self.agents)
+            self.logger.info(
+                f"[Number of active agents: {num_agents}]"
+            )
+            for agent in self.agents:
+                self.logger.info(
+                    f"[Agent] [ID: {agent.id}] [Name:"
+                    f" {agent.agent_name}] [Description:"
+                    f" {agent.agent_description}] [Status] [Running]"
+                )
+ except Exception as error:
+ print(f"Error listing agents: {error}")
+ raise error
+
+ def get_agent(self, agent_id):
+ """Get agent by id
+
+ Args:
+            agent_id: The id of the agent to retrieve.
+
+        Returns:
+            Agent: The agent with the given id.
+ """
+ self.logger.info(f"Getting agent {agent_id}")
+
+ try:
+ for agent in self.agents_pool:
+ if agent.id == agent_id:
+ return agent
+ raise ValueError(f"No agent found with ID {agent_id}")
+ except Exception as error:
+ self.logger.error(f"Error getting agent: {error}")
+ raise error
+
+ def add_agent(self, agent):
+ """Add agent to the agent pool
+
+ Args:
+            agent (Agent): The agent to add to the pool.
+ """
+ self.logger.info(f"Adding agent {agent} to pool")
+ try:
+ self.agents_pool.append(agent)
+ except Exception as error:
+ print(f"Error adding agent to pool: {error}")
+ raise error
+
+ def remove_agent(self, agent_id):
+ """Remove agent from the agent pool
+
+ Args:
+            agent_id: The id of the agent to remove.
+ """
+ self.logger.info(f"Removing agent {agent_id} from pool")
+ try:
+ for agent in self.agents_pool:
+ if agent.id == agent_id:
+ self.agents_pool.remove(agent)
+ return
+ raise ValueError(f"No agent found with ID {agent_id}")
+ except Exception as error:
+ print(f"Error removing agent from pool: {error}")
+ raise error
+
+ async def async_remove_agent(self, agent_id):
+ """Remove agent from the agent pool
+
+ Args:
+            agent_id: The id of the agent to remove.
+ """
+ self.logger.info(f"Removing agent {agent_id} from pool")
+ try:
+ # Remove agent from pool asynchronously with asyncio
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(
+ None, self.remove_agent, agent_id
+ )
+ except Exception as error:
+ print(f"Error removing agent from pool: {error}")
+ raise error
+
+ def scale_up(self, num_agents: int = 1):
+ """Scale up the agent pool
+
+ Args:
+            num_agents (int, optional): Number of agents to add. Defaults to 1.
+ """
+ self.logger.info(f"Scaling up agent pool by {num_agents}")
+ try:
+ for _ in range(num_agents):
+ self.agents_pool.append(Agent())
+ except Exception as error:
+ print(f"Error scaling up agent pool: {error}")
+ raise error
+
+ def scale_down(self, num_agents: int = 1):
+ """Scale down the agent pool
+
+ Args:
+            num_agents (int, optional): Number of agents to remove. Defaults to 1.
+ """
+ for _ in range(num_agents):
+ self.agents_pool.pop()
+
+ # - Create APIs for each agent in the pool (optional) with fastapi
+ def create_apis_for_agents(self):
+ """Create APIs for each agent in the pool (optional) with fastapi
+
+        Returns:
+            None: Routes are registered on self.api and collected in self.apis.
+        """
+        self.apis = []
+        for agent in self.agents:
+
+            @self.api.get(f"/{agent.id}")
+            def run_agent(task: str, *args, **kwargs):
+                return agent.run(task, *args, **kwargs)
+
+        self.apis.append(self.api)
+
+ def run(self):
+ """run the swarm network"""
+ # Observe all agents in the pool
+ self.logger.info("Starting the SwarmNetwork")
+
+ for agent in self.agents:
+ self.logger.info(f"Starting agent {agent.id}")
+ self.logger.info(
+ f"[Agent][{agent.id}] [Status] [Running] [Awaiting"
+ " Task]"
+ )
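+
+
+# Example usage (sketch; agent configuration is assumed elsewhere):
+# from swarms.structs.agent import Agent
+#
+# agent = Agent()
+# swarm = SwarmNetwork(agents=[agent])
+# swarm.add_task("Summarize the latest metrics")
+# swarm.run()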
diff --git a/swarms/structs/task.py b/swarms/structs/task.py
index 81351b4f..9c0f8dac 100644
--- a/swarms/structs/task.py
+++ b/swarms/structs/task.py
@@ -1,4 +1,7 @@
+import sched
+import time
from dataclasses import dataclass, field
+from datetime import datetime
from typing import (
Any,
Callable,
@@ -10,23 +13,36 @@ from typing import (
from swarms.structs.agent import Agent
-# Define a generic Task that can handle different types of callable objects
@dataclass
class Task:
"""
Task class for running a task in a sequential workflow.
-
- Args:
- description (str): The description of the task.
- agent (Union[Callable, Agent]): The model or agent to execute the task.
- args (List[Any]): Additional arguments to pass to the task execution.
- kwargs (Dict[str, Any]): Additional keyword arguments to pass to the task execution.
- result (Any): The result of the task execution.
- history (List[Any]): The history of the task execution.
+ Attributes:
+ description (str): Description of the task.
+ agent (Union[Callable, Agent]): Agent or callable object to run the task.
+ args (List[Any]): Arguments to pass to the agent or callable object.
+ kwargs (Dict[str, Any]): Keyword arguments to pass to the agent or callable object.
+ result (Any): Result of the task.
+ history (List[Any]): History of the task.
+ schedule_time (datetime): Time to schedule the task.
+ scheduler (sched.scheduler): Scheduler to schedule the task.
+        trigger (Callable): Optional trigger that starts the task.
+        action (Callable): Optional callback executed after the task runs.
+        condition (Callable): Optional predicate; the task only runs when it returns True.
+ priority (int): Priority of the task.
+ dependencies (List[Task]): List of tasks that need to be completed before this task can be executed.
Methods:
- execute: Execute the task.
+ execute: Execute the task by calling the agent or model with the arguments and keyword arguments.
+ handle_scheduled_task: Handles the execution of a scheduled task.
+ set_trigger: Sets the trigger for the task.
+ set_action: Sets the action for the task.
+ set_condition: Sets the condition for the task.
+ is_completed: Checks whether the task has been completed.
+ add_dependency: Adds a task to the list of dependencies.
+ set_priority: Sets the priority of the task.
+ check_dependency_completion: Checks whether all the dependencies have been completed.
Examples:
@@ -45,34 +61,140 @@ class Task:
kwargs: Dict[str, Any] = field(default_factory=dict)
result: Any = None
history: List[Any] = field(default_factory=list)
- # logger = logging.getLogger(__name__)
+ schedule_time: datetime = None
+ scheduler = sched.scheduler(time.time, time.sleep)
+ trigger: Callable = None
+ action: Callable = None
+ condition: Callable = None
+ priority: int = 0
+ dependencies: List["Task"] = field(default_factory=list)
def execute(self):
"""
- Execute the task.
+ Execute the task by calling the agent or model with the arguments and
+ keyword arguments.
+
+ Examples:
+ >>> from swarms.structs import Task, Agent
+ >>> from swarms.models import OpenAIChat
+ >>> agent = Agent(llm=OpenAIChat(openai_api_key=""), max_loops=1, dashboard=False)
+ >>> task = Task(description="What's the weather in miami", agent=agent)
+ >>> task.execute()
+ >>> task.result
- Raises:
- ValueError: If a Agent instance is used as a task and the 'task' argument is not provided.
"""
- if isinstance(self.agent, Agent):
- # Add a prompt to notify the Agent of the sequential workflow
- if "prompt" in self.kwargs:
- self.kwargs["prompt"] += (
- f"\n\nPrevious output: {self.result}"
- if self.result
- else ""
- )
- else:
- self.kwargs["prompt"] = (
- f"Main task: {self.description}"
- + (
- f"\n\nPrevious output: {self.result}"
- if self.result
- else ""
+
+ try:
+ if isinstance(self.agent, Agent):
+ if self.condition is None or self.condition():
+ self.result = self.agent.run(
+ *self.args, **self.kwargs
)
+ self.history.append(self.result)
+
+ if self.action is not None:
+ self.action()
+ else:
+ self.result = self.agent.run(
+ *self.args, **self.kwargs
)
- self.result = self.agent.run(*self.args, **self.kwargs)
- else:
- self.result = self.agent(*self.args, **self.kwargs)
- self.history.append(self.result)
+ self.history.append(self.result)
+ except Exception as error:
+ print(f"[ERROR][Task] {error}")
+
+ def run(self):
+ self.execute()
+
+ def __call__(self):
+ self.execute()
+
+ def handle_scheduled_task(self):
+ """
+ Handles the execution of a scheduled task.
+
+ If the schedule time is not set or has already passed, the task is executed immediately.
+ Otherwise, the task is scheduled to be executed at the specified schedule time.
+ """
+ try:
+ if (
+ self.schedule_time is None
+ or self.schedule_time <= datetime.now()
+ ):
+ self.execute()
+
+ else:
+ delay = (
+ self.schedule_time - datetime.now()
+ ).total_seconds()
+ self.scheduler.enter(delay, 1, self.execute)
+                self.scheduler.run()
+ except Exception as error:
+ print(f"[ERROR][Task] {error}")
+
+ def set_trigger(self, trigger: Callable):
+ """
+ Sets the trigger for the task.
+
+ Args:
+ trigger (Callable): The trigger to set.
+ """
+ self.trigger = trigger
+
+ def set_action(self, action: Callable):
+ """
+ Sets the action for the task.
+
+ Args:
+ action (Callable): The action to set.
+ """
+ self.action = action
+
+ def set_condition(self, condition: Callable):
+ """
+ Sets the condition for the task.
+
+ Args:
+ condition (Callable): The condition to set.
+ """
+ self.condition = condition
+
+ def is_completed(self):
+ """Is the task completed?
+
+ Returns:
+            bool: True if the task has a result, False otherwise.
+ """
+ return self.result is not None
+
+ def add_dependency(self, task):
+ """Adds a task to the list of dependencies.
+
+ Args:
+            task (Task): The task this task depends on.
+ """
+ self.dependencies.append(task)
+
+ def set_priority(self, priority: int):
+ """Sets the priority of the task.
+
+ Args:
+            priority (int): The priority to assign to the task.
+ """
+ self.priority = priority
+
+ def check_dependency_completion(self):
+ """
+ Checks whether all the dependencies have been completed.
+
+ Returns:
+ bool: True if all the dependencies have been completed, False otherwise.
+ """
+ try:
+            for task in self.dependencies:
+                if not task.is_completed():
+                    return False
+            return True
+ except Exception as error:
+ print(
+ f"[ERROR][Task][check_dependency_completion] {error}"
+ )
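+
+
+# Example usage (sketch; assumes an Agent instance named `agent` is configured):
+# from datetime import datetime, timedelta
+#
+# report = Task(description="Write the report", agent=agent)
+# review = Task(description="Review the report", agent=agent)
+# review.add_dependency(report)
+# report.schedule_time = datetime.now() + timedelta(minutes=5)
+# report.handle_scheduled_task()
+# if review.check_dependency_completion():
+#     review.run()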
diff --git a/swarms/structs/team.py b/swarms/structs/team.py
new file mode 100644
index 00000000..36c773e2
--- /dev/null
+++ b/swarms/structs/team.py
@@ -0,0 +1,107 @@
+import json
+from typing import List, Optional
+
+from pydantic.v1 import BaseModel, Field, Json, root_validator
+
+from swarms.structs.agent import Agent
+from swarms.structs.task import Task
+
+
+class Team(BaseModel):
+ """
+ Class that represents a group of agents, how they should work together and
+ their tasks.
+
+ Attributes:
+ tasks (Optional[List[Task]]): List of tasks.
+ agents (Optional[List[Agent]]): List of agents in this Team.
+ architecture (str): Architecture that the Team will follow. Default is "sequential".
+ verbose (bool): Verbose mode for the Agent Execution. Default is False.
+ config (Optional[Json]): Configuration of the Team. Default is None.
+ """
+
+ tasks: Optional[List[Task]] = Field(description="List of tasks")
+ agents: Optional[List[Agent]] = Field(
+ description="List of agents in this Team."
+ )
+    architecture: str = Field(
+        description="Architecture that the Team will follow.",
+        default="sequential",
+    )
+ verbose: bool = Field(
+ description="Verbose mode for the Agent Execution",
+ default=False,
+ )
+ config: Optional[Json] = Field(
+ description="Configuration of the Team.", default=None
+ )
+
+ @root_validator(pre=True)
+ def check_config(_cls, values):
+ if not values.get("config") and (
+ not values.get("agents") and not values.get("tasks")
+ ):
+ raise ValueError(
+ "Either agents and task need to be set or config."
+ )
+
+ if values.get("config"):
+ config = json.loads(values.get("config"))
+ if not config.get("agents") or not config.get("tasks"):
+ raise ValueError(
+ "Config should have agents and tasks."
+ )
+
+ values["agents"] = [
+ Agent(**agent) for agent in config["agents"]
+ ]
+
+ tasks = []
+ for task in config["tasks"]:
+ task_agent = [
+ agt
+ for agt in values["agents"]
+ if agt.role == task["agent"]
+ ][0]
+ del task["agent"]
+ tasks.append(Task(**task, agent=task_agent))
+
+ values["tasks"] = tasks
+ return values
+
+ def run(self) -> str:
+ """
+ Kickoff the Team to work on its tasks.
+
+ Returns:
+            output (str): Output of the Team for the final task.
+ """
+ if self.architecture == "sequential":
+ return self.__sequential_loop()
+
+ def __sequential_loop(self) -> str:
+ """
+ Loop that executes the sequential architecture.
+
+ Returns:
+ output (str): Output of the Team.
+ """
+ task_outcome = None
+ for task in self.tasks:
+ # Add delegation tools to the task if the agent allows it
+ # if task.agent.allow_delegation:
+ # tools = AgentTools(agents=self.agents).tools()
+ # task.tools += tools
+
+ self.__log(f"\nWorking Agent: {task.agent.role}")
+ self.__log(f"Starting Task: {task.description} ...")
+
+            task_outcome = task.execute()
+
+ self.__log(f"Task output: {task_outcome}")
+
+ return task_outcome
+
+ def __log(self, message):
+ if self.verbose:
+ print(message)
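+
+
+# Example usage (sketch; the agent role and task description are placeholders):
+# from swarms.structs.agent import Agent
+# from swarms.structs.task import Task
+#
+# researcher = Agent()
+# team = Team(
+#     agents=[researcher],
+#     tasks=[Task(description="Research topic X", agent=researcher)],
+#     verbose=True,
+# )
+# print(team.run())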
diff --git a/swarms/swarms/utils.py b/swarms/structs/utils.py
similarity index 59%
rename from swarms/swarms/utils.py
rename to swarms/structs/utils.py
index 73da08df..3afb5fea 100644
--- a/swarms/swarms/utils.py
+++ b/swarms/structs/utils.py
@@ -1,4 +1,5 @@
-from typing import Dict, Any, List
+import json
+from typing import Dict, Any, List, Optional
from swarms.structs.agent import Agent
@@ -66,3 +67,54 @@ def distribute_tasks(
f"No agent found with ID {agent_id}. Task '{task}' is"
" not assigned."
)
+
+
+def find_token_in_text(text: str, token: str = "") -> bool:
+ """
+ Parse a block of text for a specific token.
+
+ Args:
+ text (str): The text to parse.
+ token (str): The token to find.
+
+ Returns:
+ bool: True if the token is found in the text, False otherwise.
+ """
+    # Check if the token is in the text
+    return token in text
+
+
+def extract_key_from_json(
+ json_response: str, key: str
+) -> Optional[str]:
+ """
+ Extract a specific key from a JSON response.
+
+ Args:
+ json_response (str): The JSON response to parse.
+ key (str): The key to extract.
+
+ Returns:
+ Optional[str]: The value of the key if it exists, None otherwise.
+ """
+ response_dict = json.loads(json_response)
+ return response_dict.get(key)
+
+
+def extract_tokens_from_text(
+ text: str, tokens: List[str]
+) -> List[str]:
+ """
+ Extract a list of tokens from a text response.
+
+ Args:
+ text (str): The text to parse.
+ tokens (List[str]): The tokens to extract.
+
+ Returns:
+ List[str]: The tokens that were found in the text.
+ """
+ return [token for token in tokens if token in text]
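+
+
+# Example usage (sketch):
+# find_token_in_text("The answer is <DONE>", "<DONE>")          # True
+# extract_key_from_json('{"name": "swarms"}', "name")           # "swarms"
+# extract_tokens_from_text("alpha beta", ["alpha", "gamma"])    # ["alpha"]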
diff --git a/swarms/swarms/__init__.py b/swarms/swarms/__init__.py
deleted file mode 100644
index 38ced622..00000000
--- a/swarms/swarms/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from swarms.structs.autoscaler import AutoScaler
-from swarms.swarms.god_mode import GodMode
-from swarms.swarms.multi_agent_collab import MultiAgentCollaboration
-from swarms.swarms.base import AbstractSwarm
-
-__all__ = [
- "AutoScaler",
- "GodMode",
- "MultiAgentCollaboration",
- "AbstractSwarm",
-]
diff --git a/swarms/telemetry/auto_upgrade_swarms.py b/swarms/telemetry/auto_upgrade_swarms.py
new file mode 100644
index 00000000..aead795b
--- /dev/null
+++ b/swarms/telemetry/auto_upgrade_swarms.py
@@ -0,0 +1,11 @@
+import subprocess
+from swarms.telemetry.check_update import check_for_update
+
+
+def auto_update():
+ """auto update swarms"""
+ try:
+ if check_for_update():
+ subprocess.run(["pip", "install", "--upgrade", "swarms"])
+ except Exception as e:
+ print(e)
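+
+
+# Example usage (sketch; upgrades the installed package in place):
+# auto_update()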
diff --git a/swarms/telemetry/check_update.py b/swarms/telemetry/check_update.py
new file mode 100644
index 00000000..a9b6386e
--- /dev/null
+++ b/swarms/telemetry/check_update.py
@@ -0,0 +1,46 @@
+import pkg_resources
+import requests
+from packaging import version
+
+import importlib.util
+import sys
+
+
+# borrowed from: https://stackoverflow.com/a/1051266/656011
+def check_for_package(package):
+ if package in sys.modules:
+ return True
+ elif (spec := importlib.util.find_spec(package)) is not None:
+ try:
+ module = importlib.util.module_from_spec(spec)
+
+ sys.modules[package] = module
+ spec.loader.exec_module(module)
+
+ return True
+ except ImportError:
+ return False
+ else:
+ return False
+
+
+def check_for_update():
+ """Check for updates
+
+ Returns:
+ BOOL: Flag to indicate if there is an update
+ """
+ # Fetch the latest version from the PyPI API
+ response = requests.get("https://pypi.org/pypi/swarms/json")
+ latest_version = response.json()["info"]["version"]
+
+ # Get the current version using pkg_resources
+ current_version = pkg_resources.get_distribution("swarms").version
+
+ return version.parse(latest_version) > version.parse(
+ current_version
+ )
+
+
+# out = check_for_update()
+# print(out)
diff --git a/swarms/telemetry/sys_info.py b/swarms/telemetry/sys_info.py
new file mode 100644
index 00000000..08ad1db3
--- /dev/null
+++ b/swarms/telemetry/sys_info.py
@@ -0,0 +1,158 @@
+import platform
+import subprocess
+
+import pkg_resources
+import psutil
+import toml
+
+
+def get_python_version():
+ return platform.python_version()
+
+
+def get_pip_version():
+ try:
+ pip_version = (
+ subprocess.check_output(["pip", "--version"])
+ .decode()
+ .split()[1]
+ )
+ except Exception as e:
+ pip_version = str(e)
+ return pip_version
+
+
+def get_oi_version():
+ try:
+ oi_version_cmd = (
+ subprocess.check_output(["interpreter", "--version"])
+ .decode()
+ .split()[1]
+ )
+ except Exception as e:
+ oi_version_cmd = str(e)
+ oi_version_pkg = pkg_resources.get_distribution(
+ "open-interpreter"
+ ).version
+ oi_version = oi_version_cmd, oi_version_pkg
+ return oi_version
+
+
+def get_os_version():
+ return platform.platform()
+
+
+def get_cpu_info():
+ return platform.processor()
+
+
+def get_ram_info():
+ vm = psutil.virtual_memory()
+ used_ram_gb = vm.used / (1024**3)
+ free_ram_gb = vm.free / (1024**3)
+ total_ram_gb = vm.total / (1024**3)
+ return (
+ f"{total_ram_gb:.2f} GB, used: {used_ram_gb:.2f}, free:"
+ f" {free_ram_gb:.2f}"
+ )
+
+
+def get_package_mismatches(file_path="pyproject.toml"):
+ with open(file_path, "r") as file:
+ pyproject = toml.load(file)
+ dependencies = pyproject["tool"]["poetry"]["dependencies"]
+ dev_dependencies = pyproject["tool"]["poetry"]["group"]["dev"][
+ "dependencies"
+ ]
+ dependencies.update(dev_dependencies)
+
+ installed_packages = {
+ pkg.key: pkg.version for pkg in pkg_resources.working_set
+ }
+
+ mismatches = []
+ for package, version_info in dependencies.items():
+ if isinstance(version_info, dict):
+ version_info = version_info["version"]
+ installed_version = installed_packages.get(package)
+ if installed_version and version_info.startswith("^"):
+ expected_version = version_info[1:]
+ if not installed_version.startswith(expected_version):
+ mismatches.append(
+ f"\t {package}: Mismatch,"
+ f" pyproject.toml={expected_version},"
+ f" pip={installed_version}"
+ )
+ else:
+ mismatches.append(f"\t {package}: Not found in pip list")
+
+ return "\n" + "\n".join(mismatches)
+
+
+def interpreter_info(interpreter):
+ try:
+ if interpreter.offline and interpreter.llm.api_base:
+ try:
+            curl = subprocess.check_output(
+                f"curl {interpreter.llm.api_base}",
+                shell=True,
+            )
+ except Exception as e:
+ curl = str(e)
+ else:
+ curl = "Not local"
+
+ messages_to_display = []
+ for message in interpreter.messages:
+ message = message.copy()
+ try:
+ if len(message["content"]) > 600:
+ message["content"] = (
+ message["content"][:300]
+ + "..."
+ + message["content"][-300:]
+ )
+ except Exception as e:
+ print(str(e), "for message:", message)
+ messages_to_display.append(message)
+
+ return f"""
+
+ # Interpreter Info
+
+ Vision: {interpreter.llm.supports_vision}
+ Model: {interpreter.llm.model}
+ Function calling: {interpreter.llm.supports_functions}
+ Context window: {interpreter.llm.context_window}
+ Max tokens: {interpreter.llm.max_tokens}
+
+ Auto run: {interpreter.auto_run}
+ API base: {interpreter.llm.api_base}
+ Offline: {interpreter.offline}
+
+ Curl output: {curl}
+
+ # Messages
+
+ System Message: {interpreter.system_message}
+
+ """ + "\n\n".join([str(m) for m in messages_to_display])
+    except Exception:
+ return "Error, couldn't get interpreter info"
+
+
+def system_info(interpreter):
+ oi_version = get_oi_version()
+ print(f"""
+ Python Version: {get_python_version()}
+ Pip Version: {get_pip_version()}
+ Open-interpreter Version: cmd:{oi_version[0]}, pkg: {oi_version[1]}
+ OS Version and Architecture: {get_os_version()}
+ CPU Info: {get_cpu_info()}
+ RAM Info: {get_ram_info()}
+ {interpreter_info(interpreter)}
+ """)
+
+ # Removed the following, as it causes `FileNotFoundError: [Errno 2] No such file or directory: 'pyproject.toml'`` on prod
+ # (i think it works on dev, but on prod the pyproject.toml will not be in the cwd. might not be accessible at all)
+ # Package Version Mismatches:
+ # {get_package_mismatches()}
diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py
index 72fc7199..c1479507 100644
--- a/swarms/utils/__init__.py
+++ b/swarms/utils/__init__.py
@@ -1,25 +1,30 @@
+from swarms.utils.class_args_wrapper import print_class_parameters
from swarms.utils.code_interpreter import SubprocessCodeInterpreter
-from swarms.utils.markdown_message import display_markdown_message
-from swarms.utils.parse_code import (
- extract_code_in_backticks_in_string,
-)
-from swarms.utils.pdf_to_text import pdf_to_text
-from swarms.utils.math_eval import math_eval
-from swarms.utils.llm_metrics_decorator import metrics_decorator
from swarms.utils.device_checker_cuda import check_device
+from swarms.utils.find_img_path import find_image_path
+from swarms.utils.llm_metrics_decorator import metrics_decorator
from swarms.utils.load_model_torch import load_model_torch
+from swarms.utils.markdown_message import display_markdown_message
+from swarms.utils.math_eval import math_eval
+from swarms.utils.parse_code import extract_code_from_markdown
+from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)
+from swarms.utils.token_count_tiktoken import limit_tokens_from_string
+
__all__ = [
- "display_markdown_message",
"SubprocessCodeInterpreter",
- "extract_code_in_backticks_in_string",
- "pdf_to_text",
+ "display_markdown_message",
+ "extract_code_from_markdown",
+ "find_image_path",
+ "limit_tokens_from_string",
+ "load_model_torch",
"math_eval",
"metrics_decorator",
- "check_device",
- "load_model_torch",
+ "pdf_to_text",
"prep_torch_inference",
+ "print_class_parameters",
+ "check_device",
]
diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py
index 98fbab70..9e27b668 100644
--- a/swarms/utils/code_interpreter.py
+++ b/swarms/utils/code_interpreter.py
@@ -5,22 +5,7 @@ import time
import traceback
-class BaseCodeInterpreter:
- """
- .run is a generator that yields a dict with attributes: active_line, output
- """
-
- def __init__(self):
- pass
-
- def run(self, code):
- pass
-
- def terminate(self):
- pass
-
-
-class SubprocessCodeInterpreter(BaseCodeInterpreter):
+class SubprocessCodeInterpreter:
"""
SubprocessCodeinterpreter is a base class for code interpreters that run code in a subprocess.
@@ -43,12 +28,36 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
self.done = threading.Event()
def detect_active_line(self, line):
+ """Detect if the line is an active line
+
+ Args:
+            line (str): A line of subprocess output.
+
+        Returns:
+            The active line marker if detected, otherwise None.
+ """
return None
def detect_end_of_execution(self, line):
+ """detect if the line is an end of execution line
+
+ Args:
+            line (str): A line of subprocess output.
+
+        Returns:
+            The end-of-execution marker if detected, otherwise None.
+ """
return None
def line_postprocessor(self, line):
+ """Line postprocessor
+
+ Args:
+            line (str): A line of subprocess output.
+
+        Returns:
+            str: The post-processed line.
+ """
return line
def preprocess_code(self, code):
@@ -61,9 +70,11 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
return code
def terminate(self):
+ """terminate the subprocess"""
self.process.terminate()
def start_process(self):
+ """start the subprocess"""
if self.process:
self.terminate()
@@ -88,6 +99,14 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
).start()
def run(self, code: str):
+ """Run the code in the subprocess
+
+ Args:
+            code (str): The code to run in the subprocess.
+
+        Yields:
+            dict: Output chunks produced by the subprocess.
+ """
retry_count = 0
max_retries = 3
@@ -157,6 +176,12 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
break
def handle_stream_output(self, stream, is_error_stream):
+ """Handle the output from the subprocess
+
+ Args:
+            stream: The stdout or stderr stream to read from.
+            is_error_stream (bool): Whether the stream is stderr.
+ """
for line in iter(stream.readline, ""):
if self.debug_mode:
print(f"Received output line:\n{line}\n---")
@@ -179,3 +204,12 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
self.done.set()
else:
self.output_queue.put({"output": line})
+
+
+# interpreter = SubprocessCodeInterpreter()
+# interpreter.start_cmd = "python3"
+# for output in interpreter.run("""
+# print("hello")
+# print("world")
+# """):
+# print(output)
diff --git a/swarms/utils/find_img_path.py b/swarms/utils/find_img_path.py
new file mode 100644
index 00000000..2ca5d082
--- /dev/null
+++ b/swarms/utils/find_img_path.py
@@ -0,0 +1,24 @@
+import os
+import re
+
+
+def find_image_path(text):
+ """Find the image path from the text
+
+ Args:
+        text (str): The text to search for image paths.
+
+    Returns:
+        Optional[str]: The longest existing image path found on disk, or None.
+ """
+ pattern = r"([A-Za-z]:\\[^:\n]*?\.(png|jpg|jpeg|PNG|JPG|JPEG))|(/[^:\n]*?\.(png|jpg|jpeg|PNG|JPG|JPEG))"
+ matches = [
+ match.group()
+ for match in re.finditer(pattern, text)
+ if match.group()
+ ]
+ matches += [match.replace("\\", "") for match in matches if match]
+ existing_paths = [
+ match for match in matches if os.path.exists(match)
+ ]
+ return max(existing_paths, key=len) if existing_paths else None
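+
+
+# Example usage (sketch; a path is only returned if it exists on disk):
+# find_image_path("The chart was saved to /tmp/plot.png for review")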
diff --git a/swarms/utils/markdown_message.py b/swarms/utils/markdown_message.py
index 0fe9c2c0..57cd285f 100644
--- a/swarms/utils/markdown_message.py
+++ b/swarms/utils/markdown_message.py
@@ -1,23 +1,27 @@
-from rich import print as rich_print
+from rich.console import Console
from rich.markdown import Markdown
from rich.rule import Rule
-def display_markdown_message(message: str):
+def display_markdown_message(message: str, color: str = "cyan"):
"""
Display markdown message. Works with multiline strings with lots of indentation.
Will automatically make single line > tags beautiful.
"""
+ console = Console()
for line in message.split("\n"):
line = line.strip()
if line == "":
- print("")
+ console.print("")
elif line == "---":
- rich_print(Rule(style="white"))
+ console.print(Rule(style=color))
else:
- rich_print(Markdown(line))
+ console.print(Markdown(line, style=color))
if "\n" not in message and message.startswith(">"):
# Aesthetic choice. For these tags, they need a space below them
- print("")
+ console.print("")
+
+
+# display_markdown_message("I love you and you are beautiful.", "cyan")
diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py
index 2e0fa438..838d5868 100644
--- a/swarms/utils/parse_code.py
+++ b/swarms/utils/parse_code.py
@@ -1,31 +1,19 @@
import re
-# def extract_code_in_backticks_in_string(s: str) -> str:
-# """
-# Extracts code blocks from a markdown string.
-# Args:
-# s (str): The markdown string to extract code from.
-
-# Returns:
-# list: A list of tuples. Each tuple contains the language of the code block (if specified) and the code itself.
-# """
-# pattern = r"```([\w\+\#\-\.\s]*)\n(.*?)```"
-# matches = re.findall(pattern, s, re.DOTALL)
-# out = [(match[0], match[1].strip()) for match in matches]
-# print(out)
-
-
-def extract_code_in_backticks_in_string(s: str) -> str:
+def extract_code_from_markdown(markdown_content: str):
"""
- Extracts code blocks from a markdown string.
+ Extracts code blocks from a Markdown string and returns them as a single string.
Args:
- s (str): The markdown string to extract code from.
+ - markdown_content (str): The Markdown content as a string.
Returns:
- str: A string containing all the code blocks.
+ - str: A single string containing all the code blocks separated by newlines.
"""
- pattern = r"```([\w\+\#\-\.\s]*)(.*?)```"
- matches = re.findall(pattern, s, re.DOTALL)
- return "\n".join(match[1].strip() for match in matches)
+ # Regular expression for fenced code blocks
+ pattern = r"```(?:\w+\n)?(.*?)```"
+ matches = re.findall(pattern, markdown_content, re.DOTALL)
+
+ # Concatenate all code blocks separated by newlines
+ return "\n".join(code.strip() for code in matches)
diff --git a/task.py b/task.py
new file mode 100644
index 00000000..089cb263
--- /dev/null
+++ b/task.py
@@ -0,0 +1,47 @@
+from swarms.structs import Task, Agent
+from swarms.models import OpenAIChat
+from dotenv import load_dotenv
+import os
+
+
+# Load the environment variables
+load_dotenv()
+
+
+# Define a function to be used as the action
+def my_action():
+ print("Action executed")
+
+
+# Define a function to be used as the condition
+def my_condition():
+ print("Condition checked")
+ return True
+
+
+# Create an agent
+agent = Agent(
+ llm=OpenAIChat(openai_api_key=os.environ["OPENAI_API_KEY"]),
+ max_loops=1,
+ dashboard=False,
+)
+
+# Create a task
+task = Task(description="What's the weather in miami", agent=agent)
+
+# Set the action and condition
+task.set_action(my_action)
+task.set_condition(my_condition)
+
+# Execute the task
+print("Executing task...")
+task.run()
+
+# Check if the task is completed
+if task.is_completed():
+ print("Task completed")
+else:
+ print("Task not completed")
+
+# Output the result of the task
+print(f"Task result: {task.result}")
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/models/test_multion.py b/tests/models/test_multion.py
deleted file mode 100644
index cc91b421..00000000
--- a/tests/models/test_multion.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import pytest
-from unittest.mock import Mock, patch
-from swarms.models.multion import MultiOn
-
-
-@pytest.fixture
-def multion_instance():
- return MultiOn()
-
-
-@pytest.fixture
-def mock_multion():
- return Mock()
-
-
-def test_multion_import():
- with pytest.raises(ImportError):
- pass
-
-
-def test_multion_init():
- multion = MultiOn()
- assert isinstance(multion, MultiOn)
-
-
-def test_multion_run_with_valid_input(multion_instance, mock_multion):
- task = "Order chicken tendies"
- url = "https://www.google.com/"
- mock_multion.new_session.return_value = (
- "Order chicken tendies. https://www.google.com/"
- )
-
- with patch("swarms.models.multion.multion", mock_multion):
- response = multion_instance.run(task, url)
-
- assert (
- response == "Order chicken tendies. https://www.google.com/"
- )
-
-
-def test_multion_run_with_invalid_input(
- multion_instance, mock_multion
-):
- task = ""
- url = "https://www.google.com/"
- mock_multion.new_session.return_value = None
-
- with patch("swarms.models.multion.multion", mock_multion):
- response = multion_instance.run(task, url)
-
- assert response is None
-
-
-# Add more test cases to cover different scenarios, edge cases, and error handling as needed.
diff --git a/tests/models/test_open_dalle.py b/tests/models/test_open_dalle.py
new file mode 100644
index 00000000..2483d705
--- /dev/null
+++ b/tests/models/test_open_dalle.py
@@ -0,0 +1,59 @@
+import pytest
+import torch
+from swarms.models.open_dalle import OpenDalle
+
+
+def test_init():
+ od = OpenDalle()
+ assert isinstance(od, OpenDalle)
+
+
+def test_init_custom_model():
+ od = OpenDalle(model_name="custom_model")
+ assert od.pipeline.model_name == "custom_model"
+
+
+def test_init_custom_dtype():
+ od = OpenDalle(torch_dtype=torch.float32)
+ assert od.pipeline.torch_dtype == torch.float32
+
+
+def test_init_custom_device():
+ od = OpenDalle(device="cpu")
+ assert od.pipeline.device == "cpu"
+
+
+def test_run():
+ od = OpenDalle()
+ result = od.run("A picture of a cat")
+ assert isinstance(result, torch.Tensor)
+
+
+def test_run_no_task():
+ od = OpenDalle()
+ with pytest.raises(ValueError, match="Task cannot be None"):
+ od.run(None)
+
+
+def test_run_non_string_task():
+ od = OpenDalle()
+ with pytest.raises(TypeError, match="Task must be a string"):
+ od.run(123)
+
+
+def test_run_empty_task():
+ od = OpenDalle()
+ with pytest.raises(ValueError, match="Task cannot be empty"):
+ od.run("")
+
+
+def test_run_custom_args():
+ od = OpenDalle()
+ result = od.run("A picture of a cat", custom_arg="custom_value")
+ assert isinstance(result, torch.Tensor)
+
+
+def test_run_error():
+ od = OpenDalle()
+ with pytest.raises(Exception):
+ od.run("A picture of a cat", raise_error=True)
diff --git a/tests/models/test_ssd_1b.py b/tests/models/test_ssd_1b.py
index 35cc4864..39e4264e 100644
--- a/tests/models/test_ssd_1b.py
+++ b/tests/models/test_ssd_1b.py
@@ -162,74 +162,3 @@ def test_ssd1b_repr_str(ssd1b_model):
image_url = ssd1b_model(task)
assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
-
-
-import pytest
-from your_module import SSD1B
-
-
-# Create fixtures if needed
-@pytest.fixture
-def ssd1b_model():
- return SSD1B()
-
-
-# Test cases for additional scenarios and behaviors
-def test_ssd1b_dashboard_printing(ssd1b_model, capsys):
- ssd1b_model.dashboard = True
- ssd1b_model.print_dashboard()
- captured = capsys.readouterr()
- assert "SSD1B Dashboard:" in captured.out
-
-
-def test_ssd1b_generate_image_name(ssd1b_model):
- task = "A painting of a dog"
- img_name = ssd1b_model._generate_image_name(task)
- assert isinstance(img_name, str)
- assert len(img_name) > 0
-
-
-def test_ssd1b_set_width_height(ssd1b_model, mocker):
- img = mocker.MagicMock()
- width, height = 800, 600
- result = ssd1b_model.set_width_height(img, width, height)
- assert result == img.resize.return_value
-
-
-def test_ssd1b_read_img(ssd1b_model, mocker):
- img = mocker.MagicMock()
- result = ssd1b_model.read_img(img)
- assert result == img.open.return_value
-
-
-def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
- img = mocker.MagicMock()
- img_format = "PNG"
- result = ssd1b_model.convert_to_bytesio(img, img_format)
- assert isinstance(result, bytes)
-
-
-def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
- img = mocker.MagicMock()
- img_name = "test.png"
- save_path = tmp_path / img_name
- ssd1b_model._download_image(img, img_name, save_path)
- assert save_path.exists()
-
-
-def test_ssd1b_repr_str(ssd1b_model):
- task = "A painting of a dog"
- image_url = ssd1b_model(task)
- assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
- assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
-
-
-def test_ssd1b_rate_limited_call(ssd1b_model, mocker):
- task = "A painting of a dog"
- mocker.patch.object(
- ssd1b_model,
- "__call__",
- side_effect=Exception("Rate limit exceeded"),
- )
- with pytest.raises(Exception, match="Rate limit exceeded"):
- ssd1b_model.rate_limited_call(task)
diff --git a/tests/models/test_zeroscope.py b/tests/models/test_zeroscope.py
new file mode 100644
index 00000000..25a4c597
--- /dev/null
+++ b/tests/models/test_zeroscope.py
@@ -0,0 +1,122 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from swarms.models.zeroscope import ZeroscopeTTV
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline.from_pretrained.assert_called_once()
+ mock_scheduler.assert_called_once()
+ assert zeroscope.model_name == "cerspense/zeroscope_v2_576w"
+ assert zeroscope.chunk_size == 1
+ assert zeroscope.dim == 1
+ assert zeroscope.num_inference_steps == 40
+ assert zeroscope.height == 320
+ assert zeroscope.width == 576
+ assert zeroscope.num_frames == 36
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline_instance = MagicMock()
+ mock_pipeline.from_pretrained.return_value = (
+ mock_pipeline_instance
+ )
+ mock_pipeline_instance.return_value = MagicMock(
+ frames="Generated frames"
+ )
+ mock_pipeline_instance.enable_vae_slicing.assert_called_once()
+ mock_pipeline_instance.enable_forward_chunking.assert_called_once_with(
+ chunk_size=1, dim=1
+ )
+ result = zeroscope.forward("Test task")
+ assert result == "Generated frames"
+ mock_pipeline_instance.assert_called_once_with(
+ "Test task",
+ num_inference_steps=40,
+ height=320,
+ width=576,
+ num_frames=36,
+ )
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline_instance = MagicMock()
+ mock_pipeline.from_pretrained.return_value = (
+ mock_pipeline_instance
+ )
+ mock_pipeline_instance.return_value = MagicMock(
+ frames="Generated frames"
+ )
+ mock_pipeline_instance.side_effect = Exception("Test error")
+ with pytest.raises(Exception, match="Test error"):
+ zeroscope.forward("Test task")
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline_instance = MagicMock()
+ mock_pipeline.from_pretrained.return_value = (
+ mock_pipeline_instance
+ )
+ mock_pipeline_instance.return_value = MagicMock(
+ frames="Generated frames"
+ )
+ result = zeroscope.__call__("Test task")
+ assert result == "Generated frames"
+ mock_pipeline_instance.assert_called_once_with(
+ "Test task",
+ num_inference_steps=40,
+ height=320,
+ width=576,
+ num_frames=36,
+ )
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline_instance = MagicMock()
+ mock_pipeline.from_pretrained.return_value = (
+ mock_pipeline_instance
+ )
+ mock_pipeline_instance.return_value = MagicMock(
+ frames="Generated frames"
+ )
+ mock_pipeline_instance.side_effect = Exception("Test error")
+ with pytest.raises(Exception, match="Test error"):
+ zeroscope.__call__("Test task")
+
+
+@patch("swarms.models.zeroscope.DiffusionPipeline")
+@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
+def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
+ zeroscope = ZeroscopeTTV()
+ mock_pipeline_instance = MagicMock()
+ mock_pipeline.from_pretrained.return_value = (
+ mock_pipeline_instance
+ )
+ mock_pipeline_instance.return_value = MagicMock(
+ frames="Generated frames"
+ )
+ result = zeroscope.save_video_path("Test video path")
+ assert result == "Test video path"
+ mock_pipeline_instance.assert_called_once_with(
+ "Test video path",
+ num_inference_steps=40,
+ height=320,
+ width=576,
+ num_frames=36,
+ )
diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py
index 62abeede..ac3da51a 100644
--- a/tests/structs/test_autoscaler.py
+++ b/tests/structs/test_autoscaler.py
@@ -216,3 +216,64 @@ def test_add_task_exception(mock_put):
with pytest.raises(Exception) as e:
autoscaler.add_task("test_task")
assert str(e.value) == "test error"
+
+
+def test_autoscaler_initialization():
+ autoscaler = AutoScaler(
+ initial_agents=5,
+ scale_up_factor=2,
+ idle_threshold=0.1,
+ busy_threshold=0.8,
+ agent=agent,
+ )
+ assert isinstance(autoscaler, AutoScaler)
+ assert autoscaler.scale_up_factor == 2
+ assert autoscaler.idle_threshold == 0.1
+ assert autoscaler.busy_threshold == 0.8
+ assert len(autoscaler.agents_pool) == 5
+
+
+def test_autoscaler_add_task():
+ autoscaler = AutoScaler(agent=agent)
+ autoscaler.add_task("task1")
+ assert autoscaler.task_queue.qsize() == 1
+
+
+def test_autoscaler_scale_up():
+ autoscaler = AutoScaler(
+ initial_agents=5, scale_up_factor=2, agent=agent
+ )
+ autoscaler.scale_up()
+ assert len(autoscaler.agents_pool) == 10
+
+
+def test_autoscaler_scale_down():
+ autoscaler = AutoScaler(initial_agents=5, agent=agent)
+ autoscaler.scale_down()
+ assert len(autoscaler.agents_pool) == 4
+
+
+@patch("swarms.swarms.AutoScaler.scale_up")
+@patch("swarms.swarms.AutoScaler.scale_down")
+def test_autoscaler_monitor_and_scale(mock_scale_down, mock_scale_up):
+ autoscaler = AutoScaler(initial_agents=5, agent=agent)
+ autoscaler.add_task("task1")
+ autoscaler.monitor_and_scale()
+ mock_scale_up.assert_called_once()
+ mock_scale_down.assert_called_once()
+
+
+@patch("swarms.swarms.AutoScaler.monitor_and_scale")
+@patch("swarms.swarms.agent.run")
+def test_autoscaler_start(mock_run, mock_monitor_and_scale):
+ autoscaler = AutoScaler(initial_agents=5, agent=agent)
+ autoscaler.add_task("task1")
+ autoscaler.start()
+ mock_run.assert_called_once()
+ mock_monitor_and_scale.assert_called_once()
+
+
+def test_autoscaler_del_agent():
+ autoscaler = AutoScaler(initial_agents=5, agent=agent)
+ autoscaler.del_agent()
+ assert len(autoscaler.agents_pool) == 4
diff --git a/tests/structs/test_concurrent_workflow.py b/tests/structs/test_concurrent_workflow.py
new file mode 100644
index 00000000..206e8e2a
--- /dev/null
+++ b/tests/structs/test_concurrent_workflow.py
@@ -0,0 +1,56 @@
+from unittest.mock import Mock, create_autospec, patch
+from concurrent.futures import Future
+from swarms.structs import ConcurrentWorkflow, Task, Agent
+
+
+def test_add():
+ workflow = ConcurrentWorkflow(max_workers=2)
+ task = Mock(spec=Task)
+ workflow.add(task)
+ assert task in workflow.tasks
+
+
+def test_run():
+ workflow = ConcurrentWorkflow(max_workers=2)
+ task1 = create_autospec(Task)
+ task2 = create_autospec(Task)
+ workflow.add(task1)
+ workflow.add(task2)
+
+ with patch(
+ "concurrent.futures.ThreadPoolExecutor"
+ ) as mock_executor:
+ future1 = Future()
+ future1.set_result(None)
+ future2 = Future()
+ future2.set_result(None)
+
+ mock_executor.return_value.__enter__.return_value.submit.side_effect = [
+ future1,
+ future2,
+ ]
+ mock_executor.return_value.__enter__.return_value.as_completed.return_value = [
+ future1,
+ future2,
+ ]
+
+ workflow.run()
+
+ task1.execute.assert_called_once()
+ task2.execute.assert_called_once()
+
+
+def test_execute_task():
+ workflow = ConcurrentWorkflow(max_workers=2)
+ task = create_autospec(Task)
+ workflow._execute_task(task)
+ task.execute.assert_called_once()
+
+
+def test_agent_execution():
+ workflow = ConcurrentWorkflow(max_workers=2)
+ agent = create_autospec(Agent)
+ task = Task(agent)
+ workflow.add(task)
+ workflow._execute_task(task)
+ agent.execute.assert_called_once()
diff --git a/tests/swarms/test_groupchat.py b/tests/structs/test_groupchat.py
similarity index 99%
rename from tests/swarms/test_groupchat.py
rename to tests/structs/test_groupchat.py
index ce17a4d2..e8096d9c 100644
--- a/tests/swarms/test_groupchat.py
+++ b/tests/structs/test_groupchat.py
@@ -3,7 +3,7 @@ import pytest
from swarms.models import OpenAIChat
from swarms.models.anthropic import Anthropic
from swarms.structs.agent import Agent
-from swarms.swarms.groupchat import GroupChat, GroupChatManager
+from swarms.structs.groupchat import GroupChat, GroupChatManager
llm = OpenAIChat()
llm2 = Anthropic()
diff --git a/tests/structs/test_model_parallizer.py b/tests/structs/test_model_parallizer.py
new file mode 100644
index 00000000..37ca43db
--- /dev/null
+++ b/tests/structs/test_model_parallizer.py
@@ -0,0 +1,146 @@
+import pytest
+from swarms.structs.model_parallizer import ModelParallelizer
+from swarms.models import (
+ HuggingfaceLLM,
+ Mixtral,
+ GPT4VisionAPI,
+ ZeroscopeTTV,
+)
+
+# Initialize the models
+custom_config = {
+ "quantize": True,
+ "quantization_config": {"load_in_4bit": True},
+ "verbose": True,
+}
+huggingface_llm = HuggingfaceLLM(
+ model_id="NousResearch/Nous-Hermes-2-Vision-Alpha",
+ **custom_config,
+)
+mixtral = Mixtral(load_in_4bit=True, use_flash_attention_2=True)
+gpt4_vision_api = GPT4VisionAPI(max_tokens=1000)
+zeroscope_ttv = ZeroscopeTTV()
+
+
+def test_init():
+ mp = ModelParallelizer(
+ [
+ huggingface_llm,
+ mixtral,
+ gpt4_vision_api,
+ zeroscope_ttv,
+ ]
+ )
+ assert isinstance(mp, ModelParallelizer)
+
+
+def test_run():
+ mp = ModelParallelizer([huggingface_llm])
+ result = mp.run(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ assert isinstance(result, str)
+
+
+def test_run_all():
+ mp = ModelParallelizer(
+ [
+ huggingface_llm,
+ mixtral,
+ gpt4_vision_api,
+ zeroscope_ttv,
+ ]
+ )
+ result = mp.run_all(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ assert isinstance(result, list)
+    assert len(result) == 4
+
+
+def test_add_llm():
+ mp = ModelParallelizer([huggingface_llm])
+ mp.add_llm(mixtral)
+ assert len(mp.llms) == 2
+
+
+def test_remove_llm():
+ mp = ModelParallelizer([huggingface_llm, mixtral])
+ mp.remove_llm(mixtral)
+ assert len(mp.llms) == 1
+
+
+def test_save_responses_to_file(tmp_path):
+ mp = ModelParallelizer([huggingface_llm])
+ mp.run(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ file = tmp_path / "responses.txt"
+ mp.save_responses_to_file(file)
+ assert file.read_text() != ""
+
+
+def test_get_task_history():
+ mp = ModelParallelizer([huggingface_llm])
+ mp.run(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ assert mp.get_task_history() == [
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ ]
+
+
+def test_summary(capsys):
+ mp = ModelParallelizer([huggingface_llm])
+ mp.run(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ mp.summary()
+ captured = capsys.readouterr()
+ assert "Tasks History:" in captured.out
+
+
+def test_enable_load_balancing():
+ mp = ModelParallelizer([huggingface_llm])
+ mp.enable_load_balancing()
+ assert mp.load_balancing is True
+
+
+def test_disable_load_balancing():
+ mp = ModelParallelizer([huggingface_llm])
+ mp.disable_load_balancing()
+ assert mp.load_balancing is False
+
+
+def test_concurrent_run():
+ mp = ModelParallelizer([huggingface_llm, mixtral])
+ result = mp.concurrent_run(
+ "Create a list of known biggest risks of structural collapse"
+ " with references"
+ )
+ assert isinstance(result, list)
+ assert len(result) == 2
+
+
+def test_concurrent_run_no_task():
+ mp = ModelParallelizer([huggingface_llm])
+ with pytest.raises(TypeError):
+ mp.concurrent_run()
+
+
+def test_concurrent_run_non_string_task():
+ mp = ModelParallelizer([huggingface_llm])
+ with pytest.raises(TypeError):
+ mp.concurrent_run(123)
+
+
+def test_concurrent_run_empty_task():
+ mp = ModelParallelizer([huggingface_llm])
+ result = mp.concurrent_run("")
+ assert result == [""]
diff --git a/tests/swarms/test_multi_agent_collab.py b/tests/structs/test_multi_agent_collab.py
similarity index 98%
rename from tests/swarms/test_multi_agent_collab.py
rename to tests/structs/test_multi_agent_collab.py
index 4d85a436..05b914b4 100644
--- a/tests/swarms/test_multi_agent_collab.py
+++ b/tests/structs/test_multi_agent_collab.py
@@ -4,7 +4,7 @@ import pytest
from unittest.mock import Mock
from swarms.structs import Agent
from swarms.models import OpenAIChat
-from swarms.swarms.multi_agent_collab import (
+from swarms.structs.multi_agent_collab import (
MultiAgentCollaboration,
)
diff --git a/tests/structs/test_swarmnetwork.py b/tests/structs/test_swarmnetwork.py
new file mode 100644
index 00000000..9264ee8d
--- /dev/null
+++ b/tests/structs/test_swarmnetwork.py
@@ -0,0 +1,50 @@
+import pytest
+from unittest.mock import Mock, patch
+from swarms.structs.swarm_net import SwarmNetwork
+from swarms.structs.agent import Agent
+
+
+@pytest.fixture
+def swarm_network():
+ agents = [Agent(id=f"Agent_{i}") for i in range(5)]
+ return SwarmNetwork(agents=agents)
+
+
+def test_swarm_network_init(swarm_network):
+ assert isinstance(swarm_network.agents, list)
+ assert len(swarm_network.agents) == 5
+
+
+@patch("swarms.structs.swarm_net.SwarmNetwork.logger")
+def test_run(mock_logger, swarm_network):
+ swarm_network.run()
+ assert (
+ mock_logger.info.call_count == 10
+ ) # 2 log messages per agent
+
+
+def test_run_with_mocked_agents(mocker, swarm_network):
+ mock_agents = [Mock(spec=Agent) for _ in range(5)]
+ mocker.patch.object(swarm_network, "agents", mock_agents)
+ swarm_network.run()
+ for mock_agent in mock_agents:
+ assert mock_agent.run.called
+
+
+def test_swarm_network_with_no_agents():
+ swarm_network = SwarmNetwork(agents=[])
+ assert swarm_network.agents == []
+
+
+def test_swarm_network_add_agent(swarm_network):
+ new_agent = Agent(id="Agent_5")
+ swarm_network.add_agent(new_agent)
+ assert len(swarm_network.agents) == 6
+ assert swarm_network.agents[-1] == new_agent
+
+
+def test_swarm_network_remove_agent(swarm_network):
+ agent_to_remove = swarm_network.agents[0]
+ swarm_network.remove_agent(agent_to_remove)
+ assert len(swarm_network.agents) == 4
+ assert agent_to_remove not in swarm_network.agents
diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py
index fada564a..8a76549c 100644
--- a/tests/structs/test_task.py
+++ b/tests/structs/test_task.py
@@ -9,6 +9,8 @@ from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
)
from swarms.structs.agent import Agent
from swarms.structs.task import Task
+from datetime import datetime
+from datetime import timedelta
load_dotenv()
@@ -163,3 +165,119 @@ def test_execute():
task = Task(id="5", task="Task5", result=None, agents=[agent])
# Assuming execute method returns True on successful execution
assert task.execute() is True
+
+
+def test_task_execute_with_agent(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ task = Task(description="Test task", agent=mock_agent)
+ task.execute()
+ assert task.result == "result"
+ assert task.history == ["result"]
+
+
+def test_task_execute_with_callable(mocker):
+ mock_callable = mocker.Mock()
+ mock_callable.run.return_value = "result"
+ task = Task(description="Test task", agent=mock_callable)
+ task.execute()
+ assert task.result == "result"
+ assert task.history == ["result"]
+
+
+def test_task_execute_with_condition(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ condition = mocker.Mock(return_value=True)
+ task = Task(
+ description="Test task", agent=mock_agent, condition=condition
+ )
+ task.execute()
+ assert task.result == "result"
+ assert task.history == ["result"]
+
+
+def test_task_execute_with_condition_false(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ condition = mocker.Mock(return_value=False)
+ task = Task(
+ description="Test task", agent=mock_agent, condition=condition
+ )
+ task.execute()
+ assert task.result is None
+ assert task.history == []
+
+
+def test_task_execute_with_action(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ action = mocker.Mock()
+ task = Task(
+ description="Test task", agent=mock_agent, action=action
+ )
+ task.execute()
+ assert task.result == "result"
+ assert task.history == ["result"]
+ action.assert_called_once()
+
+
+def test_task_handle_scheduled_task_now(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ task = Task(
+ description="Test task",
+ agent=mock_agent,
+ schedule_time=datetime.now(),
+ )
+ task.handle_scheduled_task()
+ assert task.result == "result"
+ assert task.history == ["result"]
+
+
+def test_task_handle_scheduled_task_future(mocker):
+ mock_agent = mocker.Mock(spec=Agent)
+ mock_agent.run.return_value = "result"
+ task = Task(
+ description="Test task",
+ agent=mock_agent,
+ schedule_time=datetime.now() + timedelta(days=1),
+ )
+ with mocker.patch.object(
+ task.scheduler, "enter"
+ ) as mock_enter, mocker.patch.object(
+ task.scheduler, "run"
+ ) as mock_run:
+ task.handle_scheduled_task()
+ mock_enter.assert_called_once()
+ mock_run.assert_called_once()
+
+
+def test_task_set_trigger():
+ task = Task(description="Test task", agent=Agent())
+
+ def trigger():
+ return True
+
+ task.set_trigger(trigger)
+ assert task.trigger == trigger
+
+
+def test_task_set_action():
+ task = Task(description="Test task", agent=Agent())
+
+ def action():
+ return True
+
+ task.set_action(action)
+ assert task.action == action
+
+
+def test_task_set_condition():
+ task = Task(description="Test task", agent=Agent())
+
+ def condition():
+ return True
+
+ task.set_condition(condition)
+ assert task.condition == condition
diff --git a/tests/structs/test_team.py b/tests/structs/test_team.py
new file mode 100644
index 00000000..44d64e18
--- /dev/null
+++ b/tests/structs/test_team.py
@@ -0,0 +1,52 @@
+import json
+import unittest
+
+from swarms.models import OpenAIChat
+from swarms.structs import Agent, Task
+from swarms.structs.team import Team
+
+
+class TestTeam(unittest.TestCase):
+ def setUp(self):
+ self.agent = Agent(
+ llm=OpenAIChat(openai_api_key=""),
+ max_loops=1,
+ dashboard=False,
+ )
+ self.task = Task(
+ description="What's the weather in miami",
+ agent=self.agent,
+ )
+ self.team = Team(
+ tasks=[self.task],
+ agents=[self.agent],
+ architecture="sequential",
+ verbose=False,
+ )
+
+ def test_check_config(self):
+ with self.assertRaises(ValueError):
+ self.team.check_config({"config": None})
+
+ with self.assertRaises(ValueError):
+ self.team.check_config(
+ {"config": json.dumps({"agents": [], "tasks": []})}
+ )
+
+ def test_run(self):
+ self.assertEqual(self.team.run(), self.task.execute())
+
+ def test_sequential_loop(self):
+ self.assertEqual(
+ self.team._Team__sequential_loop(), self.task.execute()
+ )
+
+ def test_log(self):
+ self.assertIsNone(self.team._Team__log("Test message"))
+
+ self.team.verbose = True
+ self.assertIsNone(self.team._Team__log("Test message"))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/swarms/test_autoscaler.py b/tests/swarms/test_autoscaler.py
deleted file mode 100644
index fbf63637..00000000
--- a/tests/swarms/test_autoscaler.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from unittest.mock import patch
-from swarms.structs.autoscaler import AutoScaler
-from swarms.models import OpenAIChat
-from swarms.structs import Agent
-
-llm = OpenAIChat()
-
-agent = Agent(
- llm=llm,
- max_loops=2,
- dashboard=True,
-)
-
-
-def test_autoscaler_initialization():
- autoscaler = AutoScaler(
- initial_agents=5,
- scale_up_factor=2,
- idle_threshold=0.1,
- busy_threshold=0.8,
- agent=agent,
- )
- assert isinstance(autoscaler, AutoScaler)
- assert autoscaler.scale_up_factor == 2
- assert autoscaler.idle_threshold == 0.1
- assert autoscaler.busy_threshold == 0.8
- assert len(autoscaler.agents_pool) == 5
-
-
-def test_autoscaler_add_task():
- autoscaler = AutoScaler(agent=agent)
- autoscaler.add_task("task1")
- assert autoscaler.task_queue.qsize() == 1
-
-
-def test_autoscaler_scale_up():
- autoscaler = AutoScaler(
- initial_agents=5, scale_up_factor=2, agent=agent
- )
- autoscaler.scale_up()
- assert len(autoscaler.agents_pool) == 10
-
-
-def test_autoscaler_scale_down():
- autoscaler = AutoScaler(initial_agents=5, agent=agent)
- autoscaler.scale_down()
- assert len(autoscaler.agents_pool) == 4
-
-
-@patch("swarms.swarms.AutoScaler.scale_up")
-@patch("swarms.swarms.AutoScaler.scale_down")
-def test_autoscaler_monitor_and_scale(mock_scale_down, mock_scale_up):
- autoscaler = AutoScaler(initial_agents=5, agent=agent)
- autoscaler.add_task("task1")
- autoscaler.monitor_and_scale()
- mock_scale_up.assert_called_once()
- mock_scale_down.assert_called_once()
-
-
-@patch("swarms.swarms.AutoScaler.monitor_and_scale")
-@patch("swarms.swarms.agent.run")
-def test_autoscaler_start(mock_run, mock_monitor_and_scale):
- autoscaler = AutoScaler(initial_agents=5, agent=agent)
- autoscaler.add_task("task1")
- autoscaler.start()
- mock_run.assert_called_once()
- mock_monitor_and_scale.assert_called_once()
-
-
-def test_autoscaler_del_agent():
- autoscaler = AutoScaler(initial_agents=5, agent=agent)
- autoscaler.del_agent()
- assert len(autoscaler.agents_pool) == 4
diff --git a/tests/tools/test_base.py b/tests/tools/test_tools_base.py
similarity index 100%
rename from tests/tools/test_base.py
rename to tests/tools/test_tools_base.py
diff --git a/tests/utils/test_check_device.py b/tests/utils/test_check_device.py
new file mode 100644
index 00000000..d542803a
--- /dev/null
+++ b/tests/utils/test_check_device.py
@@ -0,0 +1,64 @@
+import torch
+import logging
+from swarms.utils import check_device
+
+# For the purposes of these tests, we assume that the `memory_allocated`
+# and `memory_reserved` functions behave the same as
+# `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`.
+
+
+def test_check_device_no_cuda(monkeypatch):
+ # Mock torch.cuda.is_available to always return False
+ monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
+
+ result = check_device(log_level=logging.DEBUG)
+ assert result.type == "cpu"
+
+
+def test_check_device_cuda_exception(monkeypatch):
+ # Mock torch.cuda.is_available to raise an exception
+ monkeypatch.setattr(
+ torch.cuda, "is_available", lambda: 1 / 0
+ ) # Raises ZeroDivisionError
+
+ result = check_device(log_level=logging.DEBUG)
+ assert result.type == "cpu"
+
+
+def test_check_device_one_cuda(monkeypatch):
+ # Mock torch.cuda.is_available to return True
+ monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+ # Mock torch.cuda.device_count to return 1
+ monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
+ # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
+ monkeypatch.setattr(
+ torch.cuda, "memory_allocated", lambda device: 0
+ )
+ monkeypatch.setattr(
+ torch.cuda, "memory_reserved", lambda device: 0
+ )
+
+ result = check_device(log_level=logging.DEBUG)
+ assert len(result) == 1
+ assert result[0].type == "cuda"
+ assert result[0].index == 0
+
+
+def test_check_device_multiple_cuda(monkeypatch):
+ # Mock torch.cuda.is_available to return True
+ monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+ # Mock torch.cuda.device_count to return 4
+ monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
+ # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
+ monkeypatch.setattr(
+ torch.cuda, "memory_allocated", lambda device: 0
+ )
+ monkeypatch.setattr(
+ torch.cuda, "memory_reserved", lambda device: 0
+ )
+
+ result = check_device(log_level=logging.DEBUG)
+ assert len(result) == 4
+ for i in range(4):
+ assert result[i].type == "cuda"
+ assert result[i].index == i
diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py
index d846f786..a222ffe9 100644
--- a/tests/utils/test_class_args_wrapper.py
+++ b/tests/utils/test_class_args_wrapper.py
@@ -2,11 +2,9 @@ import pytest
from io import StringIO
from contextlib import redirect_stdout
from swarms.utils.class_args_wrapper import print_class_parameters
-from swarms.structs import Agent, Autoscaler
+from swarms.structs.agent import Agent
from fastapi import FastAPI
from fastapi.testclient import TestClient
-from swarms.utils.class_args_wrapper import print_class_parameters
-from swarms.structs import Agent, Autoscaler
app = FastAPI()
@@ -24,19 +22,6 @@ def test_print_class_parameters_agent():
assert output == expected_output
-def test_print_class_parameters_autoscaler():
- f = StringIO()
- with redirect_stdout(f):
- print_class_parameters(Autoscaler)
- output = f.getvalue().strip()
- # Replace with the expected output for Autoscaler class
- expected_output = (
- "Parameter: min_agents, Type: \nParameter:"
- " max_agents, Type: "
- )
- assert output == expected_output
-
-
def test_print_class_parameters_error():
with pytest.raises(TypeError):
print_class_parameters("Not a class")
@@ -44,7 +29,7 @@ def test_print_class_parameters_error():
@app.get("/parameters/{class_name}")
def get_parameters(class_name: str):
- classes = {"Agent": Agent, "Autoscaler": Autoscaler}
+ classes = {"Agent": Agent}
if class_name in classes:
return print_class_parameters(
classes[class_name], api_format=True
@@ -64,17 +49,6 @@ def test_get_parameters_agent():
assert response.json() == expected_output
-def test_get_parameters_autoscaler():
- response = client.get("/parameters/Autoscaler")
- assert response.status_code == 200
- # Replace with the expected output for Autoscaler class
- expected_output = {
- "min_agents": "",
- "max_agents": "",
- }
- assert response.json() == expected_output
-
-
def test_get_parameters_not_found():
response = client.get("/parameters/NonexistentClass")
assert response.status_code == 200
diff --git a/tests/utils/test_display_markdown_message.py b/tests/utils/test_display_markdown_message.py
new file mode 100644
index 00000000..048038b2
--- /dev/null
+++ b/tests/utils/test_display_markdown_message.py
@@ -0,0 +1,65 @@
+# import necessary modules
+import pytest
+from swarms.utils import display_markdown_message
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.rule import Rule
+from unittest import mock
+
+
+def test_basic_message():
+ # Test basic message functionality
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message("This is a test")
+ mock_print.assert_called_once_with(
+ Markdown("This is a test", style="cyan")
+ )
+
+
+def test_empty_message():
+ # Test how function handles empty input
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message("")
+ mock_print.assert_called_once_with("")
+
+
+@pytest.mark.parametrize("color", ["cyan", "red", "blue"])
+def test_colors(color):
+ # Test different colors
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message("This is a test", color)
+ mock_print.assert_called_once_with(
+ Markdown("This is a test", style=color)
+ )
+
+
+def test_dash_line():
+ # Test how function handles "---"
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message("---")
+ mock_print.assert_called_once_with(Rule(style="cyan"))
+
+
+def test_message_with_whitespace():
+ # Test how function handles message with whitespaces
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message(" \n Test \n --- \n Test \n")
+ calls = [
+ mock.call(""),
+ mock.call(Markdown("Test", style="cyan")),
+ mock.call(Rule(style="cyan")),
+ mock.call(Markdown("Test", style="cyan")),
+ mock.call(""),
+ ]
+ mock_print.assert_has_calls(calls)
+
+
+def test_message_start_with_greater_than():
+ # Test how function handles message line starting with ">"
+ with mock.patch.object(Console, "print") as mock_print:
+ display_markdown_message(">This is a test")
+ calls = [
+ mock.call(Markdown(">This is a test", style="cyan")),
+ mock.call(""),
+ ]
+ mock_print.assert_has_calls(calls)
diff --git a/tests/utils/test_extract_code_from_markdown.py b/tests/utils/test_extract_code_from_markdown.py
new file mode 100644
index 00000000..9d37fc94
--- /dev/null
+++ b/tests/utils/test_extract_code_from_markdown.py
@@ -0,0 +1,47 @@
+import pytest
+from swarms.utils import extract_code_from_markdown
+
+
+@pytest.fixture
+def markdown_content_with_code():
+ return """
+ # This is a markdown document
+
+ Some intro text here.
+Some additional text.
+"""
+
+
+@pytest.fixture
+def markdown_content_without_code():
+ return """
+ # This is a markdown document
+
+ There is no code in this document.
+ """
+
+
+def test_extract_code_from_markdown_with_code(
+ markdown_content_with_code,
+):
+ extracted_code = extract_code_from_markdown(
+ markdown_content_with_code
+ )
+ assert "def my_func():" in extracted_code
+ assert 'print("This is my function.")' in extracted_code
+ assert "class MyClass:" in extracted_code
+ assert "pass" in extracted_code
+
+
+def test_extract_code_from_markdown_without_code(
+ markdown_content_without_code,
+):
+ extracted_code = extract_code_from_markdown(
+ markdown_content_without_code
+ )
+ assert extracted_code == ""
+
+
+def test_extract_code_from_markdown_exception():
+ with pytest.raises(TypeError):
+ extract_code_from_markdown(None)
diff --git a/tests/utils/test_find_image_path.py b/tests/utils/test_find_image_path.py
new file mode 100644
index 00000000..9fbc09ee
--- /dev/null
+++ b/tests/utils/test_find_image_path.py
@@ -0,0 +1,52 @@
+# Filename: test_find_image_path.py
+
+import pytest
+from swarms.utils import find_image_path
+import os
+
+
+def test_find_image_path_no_images():
+ assert (
+ find_image_path(
+ "This is a test string without any image paths."
+ )
+ is None
+ )
+
+
+def test_find_image_path_one_image():
+ text = "This is a string with one image path: sample_image.jpg."
+ assert find_image_path(text) == "sample_image.jpg"
+
+
+def test_find_image_path_multiple_images():
+ text = "This string has two image paths: img1.png, and img2.jpg."
+ assert (
+ find_image_path(text) == "img2.jpg"
+ ) # Assuming both images exist
+
+
+def test_find_image_path_wrong_input():
+ with pytest.raises(TypeError):
+ find_image_path(123)
+
+
+@pytest.mark.parametrize(
+ "text, expected",
+ [
+ ("no image path here", None),
+ ("image: sample.png", "sample.png"),
+ ("image: sample.png, another: another.jpeg", "another.jpeg"),
+ ],
+)
+def test_find_image_path_parameterized(text, expected):
+ assert find_image_path(text) == expected
+
+
+def mock_os_path_exists(path):
+ return True
+
+
+def test_find_image_path_mocking(monkeypatch):
+ monkeypatch.setattr(os.path, "exists", mock_os_path_exists)
+ assert find_image_path("image.jpg") == "image.jpg"
diff --git a/tests/utils/test_limit_tokens_from_string.py b/tests/utils/test_limit_tokens_from_string.py
new file mode 100644
index 00000000..5b5f8efd
--- /dev/null
+++ b/tests/utils/test_limit_tokens_from_string.py
@@ -0,0 +1,45 @@
+import pytest
+from swarms.utils import limit_tokens_from_string
+
+
+def test_limit_tokens_from_string():
+ sentence = (
+ "This is a test sentence. It is used for testing the number"
+ " of tokens."
+ )
+ limited = limit_tokens_from_string(sentence, limit=5)
+ assert (
+ len(limited.split()) <= 5
+ ), "The output string has more than 5 tokens."
+
+
+def test_limit_zero_tokens():
+ sentence = "Expect empty result when limit is set to zero."
+ limited = limit_tokens_from_string(sentence, limit=0)
+ assert limited == "", "The output is not empty."
+
+
+def test_negative_token_limit():
+ sentence = (
+ "This test will raise an exception when limit is negative."
+ )
+ with pytest.raises(Exception):
+ limit_tokens_from_string(sentence, limit=-1)
+
+
+@pytest.mark.parametrize(
+ "sentence, model", [("Some sentence", "unavailable-model")]
+)
+def test_unknown_model(sentence, model):
+ with pytest.raises(Exception):
+ limit_tokens_from_string(sentence, model=model)
+
+
+def test_string_token_limit_exceeded():
+ sentence = (
+ "This is a long sentence with more than twenty tokens which"
+ " is used for testing. It checks whether the function"
+ " correctly limits the tokens to a specified amount."
+ )
+ limited = limit_tokens_from_string(sentence, limit=20)
+ assert len(limited.split()) <= 20, "The token limit is exceeded."
diff --git a/tests/utils/test_load_model_torch.py b/tests/utils/test_load_model_torch.py
new file mode 100644
index 00000000..ef2c17d4
--- /dev/null
+++ b/tests/utils/test_load_model_torch.py
@@ -0,0 +1,111 @@
+import pytest
+import torch
+from torch import nn
+from swarms.utils import load_model_torch
+
+
+class DummyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc = nn.Linear(10, 2)
+
+ def forward(self, x):
+ return self.fc(x)
+
+
+# Test case 1: Test if model can be loaded successfully
+def test_load_model_torch_success(tmp_path):
+ model = DummyModel()
+ # Save the model to a temporary directory
+ model_path = tmp_path / "model.pt"
+ torch.save(model.state_dict(), model_path)
+
+ # Load the model
+ model_loaded = load_model_torch(model_path, model=DummyModel())
+
+ # Check if loaded model has the same architecture
+ assert isinstance(
+ model_loaded, DummyModel
+ ), "Loaded model type mismatch."
+
+
+# Test case 2: Test if function raises FileNotFoundError for non-existent file
+def test_load_model_torch_file_not_found():
+ with pytest.raises(FileNotFoundError):
+ load_model_torch("non_existent_model.pt")
+
+
+# Test case 3: Test if function catches and raises RuntimeError for invalid model file
+def test_load_model_torch_invalid_file(tmp_path):
+ file = tmp_path / "invalid_model.pt"
+ file.write_text("Invalid model file.")
+
+ with pytest.raises(RuntimeError):
+ load_model_torch(file)
+
+
+# Test case 4: Test for handling of 'strict' parameter
+def test_load_model_torch_strict_handling(tmp_path):
+ # Create a model and modify it to cause a mismatch
+ model = DummyModel()
+ model.fc = nn.Linear(10, 3)
+ model_path = tmp_path / "model.pt"
+ torch.save(model.state_dict(), model_path)
+
+ # Try to load the modified model with 'strict' parameter set to True
+ with pytest.raises(RuntimeError):
+ load_model_torch(model_path, model=DummyModel(), strict=True)
+
+
+# Test case 5: Test for 'device' parameter handling
+def test_load_model_torch_device_handling(tmp_path):
+ model = DummyModel()
+ model_path = tmp_path / "model.pt"
+ torch.save(model.state_dict(), model_path)
+
+ # Define a device other than default and load the model to the specified device
+ device = torch.device("cpu")
+ model_loaded = load_model_torch(
+ model_path, model=DummyModel(), device=device
+ )
+
+ assert (
+ model_loaded.fc.weight.device == device
+ ), "Model not loaded to specified device."
+
+
+# Test case 6: Testing for correct handling of '*args' and '**kwargs'
+def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path):
+ model = DummyModel()
+ model_path = tmp_path / "model.pt"
+ torch.save(model.state_dict(), model_path)
+
+ def mock_torch_load(*args, **kwargs):
+ assert (
+ "pickle_module" in kwargs
+ ), "Keyword arguments not passed to 'torch.load'."
+
+ # Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly
+ monkeypatch.setattr(torch, "load", mock_torch_load)
+ load_model_torch(
+ model_path, model=DummyModel(), pickle_module="dummy_module"
+ )
+
+
+# Test case 7: Test for model loading on CPU if no GPU is available
+def test_load_model_torch_cpu(tmp_path, monkeypatch):
+ model = DummyModel()
+ model_path = tmp_path / "model.pt"
+ torch.save(model.state_dict(), model_path)
+
+ def mock_torch_cuda_is_available():
+ return False
+
+ # Monkeypatch to simulate no GPU available
+    monkeypatch.setattr(
+        torch.cuda, "is_available", mock_torch_cuda_is_available
+    )
+ model_loaded = load_model_torch(model_path, model=DummyModel())
+
+ # Ensure model is loaded on CPU
+ assert next(model_loaded.parameters()).device.type == "cpu"
diff --git a/tests/utils/test_math_eval.py b/tests/utils/test_math_eval.py
index 91013ae3..ae7ee04c 100644
--- a/tests/utils/test_math_eval.py
+++ b/tests/utils/test_math_eval.py
@@ -1,89 +1,41 @@
-import pytest
-from swarms.utils.math_eval import math_eval
+from swarms.utils import math_eval
-def test_math_eval_same_output():
- @math_eval(lambda x: x + 1, lambda x: x + 1)
- def func(x):
- return x
-
- for i in range(20):
- result1, result2 = func(i)
- assert result1 == result2
- assert result1 == i + 1
+def func1_no_exception(x):
+ return x + 2
-def test_math_eval_different_output():
- @math_eval(lambda x: x + 1, lambda x: x + 2)
- def func(x):
- return x
+def func2_no_exception(x):
+ return x + 2
- for i in range(20):
- result1, result2 = func(i)
- assert result1 != result2
- assert result1 == i + 1
- assert result2 == i + 2
+def func1_with_exception(x):
+ raise ValueError()
-def test_math_eval_exception_in_func1():
- @math_eval(lambda x: 1 / x, lambda x: x)
- def func(x):
- return x
- with pytest.raises(ZeroDivisionError):
- func(0)
+def func2_with_exception(x):
+ raise ValueError()
-def test_math_eval_exception_in_func2():
- @math_eval(lambda x: x, lambda x: 1 / x)
- def func(x):
+def test_same_results_no_exception(caplog):
+ @math_eval(func1_no_exception, func2_no_exception)
+ def test_func(x):
return x
- with pytest.raises(ZeroDivisionError):
- func(0)
-
-
-def test_math_eval_with_multiple_arguments():
- @math_eval(lambda x, y: x + y, lambda x, y: y + x)
- def func(x, y):
- return x, y
-
- for i in range(10):
- for j in range(10):
- result1, result2 = func(i, j)
- assert result1 == result2
- assert result1 == i + j
+ result1, result2 = test_func(5)
+ assert result1 == result2 == 7
+ assert "Outputs do not match" not in caplog.text
-def test_math_eval_with_kwargs():
- @math_eval(lambda x, y=0: x + y, lambda x, y=0: y + x)
- def func(x, y=0):
- return x, y
-
- for i in range(10):
- for j in range(10):
- result1, result2 = func(i, y=j)
- assert result1 == result2
- assert result1 == i + j
-
-
-def test_math_eval_with_no_arguments():
- @math_eval(lambda: 1, lambda: 1)
- def func():
- return
-
- result1, result2 = func()
- assert result1 == result2
- assert result1 == 1
+def test_func1_exception(caplog):
+ @math_eval(func1_with_exception, func2_no_exception)
+ def test_func(x):
+ return x
+ result1, result2 = test_func(5)
+ assert result1 is None
+ assert result2 == 7
+ assert "Error in func1:" in caplog.text
-def test_math_eval_with_different_types():
- @math_eval(lambda x: str(x), lambda x: x)
- def func(x):
- return x
- for i in range(10):
- result1, result2 = func(i)
- assert result1 != result2
- assert result1 == str(i)
- assert result2 == i
+# Similar tests for func2_with_exception and for the case where func1 and
+# func2 return different results are sketched below.
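+
+
+# A hedged sketch of the cases referenced above. It assumes math_eval logs
+# "Error in func2:" when the second function raises (mirroring the
+# "Error in func1:" message asserted earlier) and logs "Outputs do not
+# match" when the two results differ.
+def test_func2_exception(caplog):
+    @math_eval(func1_no_exception, func2_with_exception)
+    def test_func(x):
+        return x
+
+    result1, result2 = test_func(5)
+    assert result1 == 7
+    assert result2 is None
+    assert "Error in func2:" in caplog.text
+
+
+def test_different_results(caplog):
+    @math_eval(func1_no_exception, lambda x: x + 3)
+    def test_func(x):
+        return x
+
+    result1, result2 = test_func(5)
+    assert result1 == 7
+    assert result2 == 8
+    assert "Outputs do not match" in caplog.text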
diff --git a/tests/utils/test_metrics_decorator.py b/tests/utils/test_metrics_decorator.py
new file mode 100644
index 00000000..7a676657
--- /dev/null
+++ b/tests/utils/test_metrics_decorator.py
@@ -0,0 +1,84 @@
+# pytest imports
+import pytest
+from unittest.mock import Mock
+
+# Imports from your project
+from swarms.utils import metrics_decorator
+import time
+
+
+# Basic successful test
+def test_metrics_decorator_success():
+ @metrics_decorator
+ def decorated_func():
+ time.sleep(0.1)
+ return [1, 2, 3, 4, 5]
+
+ metrics = decorated_func()
+ assert "Time to First Token" in metrics
+ assert "Generation Latency" in metrics
+ assert "Throughput:" in metrics
+
+
+@pytest.mark.parametrize(
+ "wait_time, return_val",
+ [
+ (0, []),
+ (0.1, [1, 2, 3]),
+ (0.5, list(range(50))),
+ ],
+)
+def test_metrics_decorator_with_various_wait_times_and_return_vals(
+ wait_time, return_val
+):
+ @metrics_decorator
+ def decorated_func():
+ time.sleep(wait_time)
+ return return_val
+
+ metrics = decorated_func()
+ assert "Time to First Token" in metrics
+ assert "Generation Latency" in metrics
+ assert "Throughput:" in metrics
+
+
+# Test to ensure that mocked time function was called and throughputs are calculated as expected
+def test_metrics_decorator_with_mocked_time(mocker):
+ mocked_time = Mock()
+ mocker.patch("time.time", mocked_time)
+
+ mocked_time.side_effect = [0, 5, 10, 20]
+
+ @metrics_decorator
+ def decorated_func():
+ return ["tok_1", "tok_2"]
+
+ metrics = decorated_func()
+ assert metrics == """
+ Time to First Token: 5
+ Generation Latency: 20
+ Throughput: 0.1
+ """
+ mocked_time.assert_any_call()
+
+
+# Test to ensure that exceptions in the decorated function are propagated
+def test_metrics_decorator_raises_exception():
+ @metrics_decorator
+ def decorated_func():
+ raise ValueError("Oops!")
+
+ with pytest.raises(ValueError, match="Oops!"):
+ decorated_func()
+
+
+# Test to ensure proper handling when decorated function returns non-list value
+def test_metrics_decorator_with_non_list_return_val():
+ @metrics_decorator
+ def decorated_func():
+ return "Hello, world!"
+
+ metrics = decorated_func()
+ assert "Time to First Token" in metrics
+ assert "Generation Latency" in metrics
+ assert "Throughput:" in metrics
diff --git a/tests/utils/test_pdf_to_text.py b/tests/utils/test_pdf_to_text.py
new file mode 100644
index 00000000..57e3b33f
--- /dev/null
+++ b/tests/utils/test_pdf_to_text.py
@@ -0,0 +1,40 @@
+import pytest
+import PyPDF2
+from swarms.utils import pdf_to_text
+
+
+@pytest.fixture
+def pdf_file(tmpdir):
+ pdf_writer = PyPDF2.PdfWriter()
+    # Add a single blank page so the writer produces a valid PDF
+    pdf_writer.add_blank_page(width=200, height=200)
+ pdf_file = tmpdir.join("temp.pdf")
+ with open(pdf_file, "wb") as output:
+ pdf_writer.write(output)
+ return str(pdf_file)
+
+
+def test_valid_pdf_to_text(pdf_file):
+ result = pdf_to_text(pdf_file)
+ assert isinstance(result, str)
+
+
+def test_non_existing_file():
+ with pytest.raises(FileNotFoundError):
+ pdf_to_text("non_existing_file.pdf")
+
+
+def test_passing_non_pdf_file(tmpdir):
+ file = tmpdir.join("temp.txt")
+ file.write("This is a test")
+ with pytest.raises(
+ Exception,
+ match=r"An error occurred while reading the PDF file",
+ ):
+ pdf_to_text(str(file))
+
+
+@pytest.mark.parametrize("invalid_pdf_file", [None, 123, {}, []])
+def test_invalid_pdf_to_text(invalid_pdf_file):
+ with pytest.raises(Exception):
+ pdf_to_text(invalid_pdf_file)
diff --git a/tests/utils/test_prep_torch_inference.py b/tests/utils/test_prep_torch_inference.py
new file mode 100644
index 00000000..8ee33fbc
--- /dev/null
+++ b/tests/utils/test_prep_torch_inference.py
@@ -0,0 +1,49 @@
+import unittest
+import pytest
+import torch
+from unittest.mock import Mock
+from swarms.utils import prep_torch_inference
+
+
+def test_prep_torch_inference():
+ model_path = "model_path"
+ device = torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ )
+ model_mock = Mock()
+ model_mock.eval = Mock()
+
+ # Mocking the load_model_torch function to return our mock model.
+ with unittest.mock.patch(
+ "swarms.utils.load_model_torch", return_value=model_mock
+ ) as _:
+ model = prep_torch_inference(model_path, device)
+
+ # Check if model was properly loaded and eval function was called
+ assert model == model_mock
+ model_mock.eval.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "model_path, device",
+ [
+ (
+ "invalid_path",
+ torch.device("cuda"),
+ ), # Invalid file path, valid device
+ (None, torch.device("cuda")), # None file path, valid device
+ ("model_path", None), # Valid file path, None device
+ (None, None), # None file path, None device
+ ],
+)
+def test_prep_torch_inference_exceptions(model_path, device):
+ with pytest.raises(Exception):
+ prep_torch_inference(model_path, device)
+
+
+def test_prep_torch_inference_return_none():
+ model_path = "invalid_path" # Invalid file path
+ device = torch.device("cuda") # Valid device
+
+ # Since load_model_torch function will raise an exception, prep_torch_inference should return None
+ assert prep_torch_inference(model_path, device) is None
diff --git a/tests/utils/test_print_class_parameters.py b/tests/utils/test_print_class_parameters.py
new file mode 100644
index 00000000..ae824170
--- /dev/null
+++ b/tests/utils/test_print_class_parameters.py
@@ -0,0 +1,119 @@
+import pytest
+from swarms.utils import print_class_parameters
+
+
+class TestObject:
+ def __init__(self, value1, value2: int):
+ pass
+
+
+class TestObject2:
+ def __init__(self: "TestObject2", value1, value2: int = 5):
+ pass
+
+
+def test_class_with_complex_parameters():
+ class ComplexArgs:
+ def __init__(self, value1: list, value2: dict = {}):
+ pass
+
+ output = {"value1": "", "value2": ""}
+ assert (
+ print_class_parameters(ComplexArgs, api_format=True) == output
+ )
+
+
+def test_empty_class():
+ class Empty:
+ pass
+
+ with pytest.raises(Exception):
+ print_class_parameters(Empty)
+
+
+def test_class_with_no_annotations():
+ class NoAnnotations:
+ def __init__(self, value1, value2):
+ pass
+
+ output = {
+ "value1": "",
+ "value2": "",
+ }
+ assert (
+ print_class_parameters(NoAnnotations, api_format=True)
+ == output
+ )
+
+
+def test_class_with_partial_annotations():
+ class PartialAnnotations:
+ def __init__(self, value1, value2: int):
+ pass
+
+ output = {
+ "value1": "",
+ "value2": "",
+ }
+ assert (
+ print_class_parameters(PartialAnnotations, api_format=True)
+ == output
+ )
+
+
+@pytest.mark.parametrize(
+ "obj, expected",
+ [
+ (
+ TestObject,
+ {
+ "value1": "",
+ "value2": "",
+ },
+ ),
+ (
+ TestObject2,
+ {
+ "value1": "",
+ "value2": "",
+ },
+ ),
+ ],
+)
+def test_parametrized_class_parameters(obj, expected):
+ assert print_class_parameters(obj, api_format=True) == expected
+
+
+@pytest.mark.parametrize(
+ "value",
+ [
+ int,
+ float,
+ str,
+ list,
+ set,
+ dict,
+ bool,
+ tuple,
+ complex,
+ bytes,
+ bytearray,
+ memoryview,
+ range,
+ frozenset,
+ slice,
+ object,
+ ],
+)
+def test_not_class_exception(value):
+ with pytest.raises(Exception):
+ print_class_parameters(value)
+
+
+def test_api_format_flag():
+ assert print_class_parameters(TestObject2, api_format=True) == {
+ "value1": "",
+ "value2": "",
+ }
+ print_class_parameters(TestObject)
+    # TODO: capture the printed output and assert correctness; see the sketch below.
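+
+
+# A hedged sketch of the capture referred to in the TODO above. The test
+# name is new; it assumes the non-API format prints one
+# "Parameter: <name>, Type: ..." line per constructor argument, as the
+# other tests in this suite expect.
+def test_print_output_contains_parameters(capsys):
+    print_class_parameters(TestObject)
+    captured = capsys.readouterr()
+    assert "Parameter: value1" in captured.out
+    assert "Parameter: value2" in captured.out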
diff --git a/tests/utils/test_subprocess_code_interpreter.py b/tests/utils/test_subprocess_code_interpreter.py
index 2c7f7e47..3ce54530 100644
--- a/tests/utils/test_subprocess_code_interpreter.py
+++ b/tests/utils/test_subprocess_code_interpreter.py
@@ -1,307 +1,77 @@
+import pytest
import subprocess
import threading
-import time
-
-import pytest
-
+import queue
from swarms.utils.code_interpreter import (
- BaseCodeInterpreter,
SubprocessCodeInterpreter,
-)
+) # Adjust the import according to your project structure
+# Fixture for the SubprocessCodeInterpreter instance
@pytest.fixture
-def subprocess_code_interpreter():
- interpreter = SubprocessCodeInterpreter()
- interpreter.start_cmd = "python -c"
- yield interpreter
- interpreter.terminate()
-
-
-def test_base_code_interpreter_init():
- interpreter = BaseCodeInterpreter()
- assert isinstance(interpreter, BaseCodeInterpreter)
-
-
-def test_base_code_interpreter_run_not_implemented():
- interpreter = BaseCodeInterpreter()
- with pytest.raises(NotImplementedError):
- interpreter.run("code")
-
-
-def test_base_code_interpreter_terminate_not_implemented():
- interpreter = BaseCodeInterpreter()
- with pytest.raises(NotImplementedError):
- interpreter.terminate()
-
-
-def test_subprocess_code_interpreter_init(
- subprocess_code_interpreter,
-):
- assert isinstance(
- subprocess_code_interpreter, SubprocessCodeInterpreter
- )
-
-
-def test_subprocess_code_interpreter_start_process(
- subprocess_code_interpreter,
-):
- subprocess_code_interpreter.start_process()
- assert subprocess_code_interpreter.process is not None
-
-
-def test_subprocess_code_interpreter_terminate(
- subprocess_code_interpreter,
-):
- subprocess_code_interpreter.start_process()
- subprocess_code_interpreter.terminate()
- assert subprocess_code_interpreter.process.poll() is not None
-
-
-def test_subprocess_code_interpreter_run_success(
- subprocess_code_interpreter,
-):
- code = 'print("Hello, World!")'
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Hello, World!" in output.get("output", "")
- for output in result
- )
-
-
-def test_subprocess_code_interpreter_run_with_error(
- subprocess_code_interpreter,
-):
- code = 'print("Hello, World")\nraise ValueError("Error!")'
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Error!" in output.get("output", "") for output in result
- )
-
-
-def test_subprocess_code_interpreter_run_with_keyboard_interrupt(
- subprocess_code_interpreter,
-):
- code = (
- 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise'
- " KeyboardInterrupt"
- )
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "KeyboardInterrupt" in output.get("output", "")
- for output in result
- )
-
-
-def test_subprocess_code_interpreter_run_max_retries(
- subprocess_code_interpreter, monkeypatch
-):
- def mock_subprocess_popen(*args, **kwargs):
- raise subprocess.CalledProcessError(1, "mocked_cmd")
-
- monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen)
+def interpreter():
+ return SubprocessCodeInterpreter()
- code = 'print("Hello, World!")'
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Maximum retries reached. Could not execute code."
- in output.get("output", "")
- for output in result
- )
+# Test for correct initialization
+def test_initialization(interpreter):
+ assert interpreter.start_cmd == ""
+ assert interpreter.process is None
+ assert not interpreter.debug_mode
+ assert isinstance(interpreter.output_queue, queue.Queue)
+ assert isinstance(interpreter.done, threading.Event)
-def test_subprocess_code_interpreter_run_retry_on_error(
- subprocess_code_interpreter, monkeypatch
-):
- def mock_subprocess_popen(*args, **kwargs):
- nonlocal popen_count
- if popen_count == 0:
- popen_count += 1
- raise subprocess.CalledProcessError(1, "mocked_cmd")
- else:
- return subprocess.Popen(
- "echo 'Hello, World!'",
- shell=True,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
-
- monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen)
- popen_count = 0
-
- code = 'print("Hello, World!")'
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Hello, World!" in output.get("output", "")
- for output in result
- )
+# Test for starting and terminating process
+def test_start_and_terminate_process(interpreter):
+ interpreter.start_cmd = "echo Hello"
+ interpreter.start_process()
+ assert isinstance(interpreter.process, subprocess.Popen)
+ interpreter.terminate()
+ assert (
+ interpreter.process.poll() is not None
+ ) # Process should be terminated
-# Add more tests to cover other aspects of the code and edge cases as needed
-# Import statements and fixtures from the previous code block
+# Test preprocess_code method
+def test_preprocess_code(interpreter):
+ code = "print('Hello, World!')"
+ processed_code = interpreter.preprocess_code(code)
+ # Add assertions based on expected behavior of preprocess_code
+ assert processed_code == code # Example assertion
-def test_subprocess_code_interpreter_line_postprocessor(
- subprocess_code_interpreter,
-):
- line = "This is a test line"
- processed_line = subprocess_code_interpreter.line_postprocessor(
- line
- )
+# Test detect_active_line method
+def test_detect_active_line(interpreter):
+ line = "Some line of code"
assert (
- processed_line == line
- ) # No processing, should remain the same
+ interpreter.detect_active_line(line) is None
+ ) # Adjust assertion based on implementation
-def test_subprocess_code_interpreter_preprocess_code(
- subprocess_code_interpreter,
-):
- code = 'print("Hello, World!")'
- preprocessed_code = subprocess_code_interpreter.preprocess_code(
- code
- )
+# Test detect_end_of_execution method
+def test_detect_end_of_execution(interpreter):
+ line = "End of execution line"
assert (
- preprocessed_code == code
- ) # No preprocessing, should remain the same
-
-
-def test_subprocess_code_interpreter_detect_active_line(
- subprocess_code_interpreter,
-):
- line = "Active line: 5"
- active_line = subprocess_code_interpreter.detect_active_line(line)
- assert active_line == 5
-
-
-def test_subprocess_code_interpreter_detect_end_of_execution(
- subprocess_code_interpreter,
-):
- line = "Execution completed."
- end_of_execution = (
- subprocess_code_interpreter.detect_end_of_execution(line)
- )
- assert end_of_execution is True
-
-
-def test_subprocess_code_interpreter_run_debug_mode(
- subprocess_code_interpreter, capsys
-):
- subprocess_code_interpreter.debug_mode = True
- code = 'print("Hello, World!")'
- list(subprocess_code_interpreter.run(code))
- captured = capsys.readouterr()
- assert "Running code:\n" in captured.out
- assert "Received output line:\n" in captured.out
-
-
-def test_subprocess_code_interpreter_run_no_debug_mode(
- subprocess_code_interpreter, capsys
-):
- subprocess_code_interpreter.debug_mode = False
- code = 'print("Hello, World!")'
- list(subprocess_code_interpreter.run(code))
- captured = capsys.readouterr()
- assert "Running code:\n" not in captured.out
- assert "Received output line:\n" not in captured.out
-
+ interpreter.detect_end_of_execution(line) is None
+ ) # Adjust assertion based on implementation
-def test_subprocess_code_interpreter_run_empty_output_queue(
- subprocess_code_interpreter,
-):
- code = 'print("Hello, World!")'
- result = list(subprocess_code_interpreter.run(code))
- assert not any("active_line" in output for output in result)
-
-
-def test_subprocess_code_interpreter_handle_stream_output_stdout(
- subprocess_code_interpreter,
-):
- line = "This is a test line"
- subprocess_code_interpreter.handle_stream_output(
- threading.current_thread(), False
- )
- subprocess_code_interpreter.process.stdout.write(line + "\n")
- subprocess_code_interpreter.process.stdout.flush()
- time.sleep(0.1)
- output = subprocess_code_interpreter.output_queue.get()
- assert output["output"] == line
-
-
-def test_subprocess_code_interpreter_handle_stream_output_stderr(
- subprocess_code_interpreter,
-):
- line = "This is an error line"
- subprocess_code_interpreter.handle_stream_output(
- threading.current_thread(), True
- )
- subprocess_code_interpreter.process.stderr.write(line + "\n")
- subprocess_code_interpreter.process.stderr.flush()
- time.sleep(0.1)
- output = subprocess_code_interpreter.output_queue.get()
- assert output["output"] == line
-
-
-def test_subprocess_code_interpreter_run_with_preprocess_code(
- subprocess_code_interpreter, capsys
-):
- code = 'print("Hello, World!")'
- subprocess_code_interpreter.preprocess_code = (
- lambda x: x.upper()
- ) # Modify code in preprocess_code
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Hello, World!" in output.get("output", "")
- for output in result
- )
-
-
-def test_subprocess_code_interpreter_run_with_exception(
- subprocess_code_interpreter, capsys
-):
- code = 'print("Hello, World!")'
- subprocess_code_interpreter.start_cmd = ( # Force an exception during subprocess creation
- "nonexistent_command"
- )
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "Maximum retries reached" in output.get("output", "")
- for output in result
- )
-
-
-def test_subprocess_code_interpreter_run_with_active_line(
- subprocess_code_interpreter, capsys
-):
- code = "a = 5\nprint(a)" # Contains an active line
- result = list(subprocess_code_interpreter.run(code))
- assert any(output.get("active_line") == 5 for output in result)
-
-
-def test_subprocess_code_interpreter_run_with_end_of_execution(
- subprocess_code_interpreter, capsys
-):
- code = ( # Simple code without active line marker
- 'print("Hello, World!")'
- )
- result = list(subprocess_code_interpreter.run(code))
- assert any(output.get("active_line") is None for output in result)
+# Test line_postprocessor method
+def test_line_postprocessor(interpreter):
+ line = "Some output line"
+ assert (
+ interpreter.line_postprocessor(line) == line
+ ) # Adjust assertion based on implementation
-def test_subprocess_code_interpreter_run_with_multiple_lines(
- subprocess_code_interpreter, capsys
-):
- code = "a = 5\nb = 10\nprint(a + b)"
- result = list(subprocess_code_interpreter.run(code))
- assert any("15" in output.get("output", "") for output in result)
+# Test handle_stream_output method
+def test_handle_stream_output(interpreter):
+    import io
+
+    # A hedged setup: we assume handle_stream_output(stream, is_error_stream)
+    # reads lines from the given stream via readline() and pushes them onto
+    # interpreter.output_queue.
+    stream = io.StringIO("output line\n")
-def test_subprocess_code_interpreter_run_with_unicode_characters(
- subprocess_code_interpreter, capsys
-):
- code = 'print("γγγ«γ‘γ―γδΈη")' # Contains unicode characters
- result = list(subprocess_code_interpreter.run(code))
- assert any(
- "γγγ«γ‘γ―γδΈη" in output.get("output", "")
- for output in result
- )
+    interpreter.handle_stream_output(stream, False)
+    item = interpreter.output_queue.get_nowait()
+    assert "output line" in item.get("output", "")