commit 3e29a067df

@ -0,0 +1,180 @@
## `Gemini` Documentation

### Introduction

The Gemini module is a versatile tool for leveraging multimodal AI models to generate content. It allows users to combine textual and image inputs and produce creative, informative outputs. This documentation covers the module's purpose, architecture, methods, and usage examples.

#### Purpose

The Gemini module bridges the gap between text and image data, enabling users to harness the capabilities of multimodal AI models effectively. Given a textual task and an image as input, Gemini generates content that fulfills the task while incorporating the visual information from the image.

### Installation

Before using Gemini, ensure that you have the required dependencies installed. You can install them using the following commands:

```bash
pip install swarms
pip install google-generativeai
pip install python-dotenv
```

### Class: Gemini

#### Overview

The `Gemini` class is the central component of the Gemini module. It inherits from the `BaseMultiModalModel` class and provides methods to interact with the Gemini AI model. Let's dive into its architecture and functionality.

##### Class Constructor

```python
class Gemini(BaseMultiModalModel):
    def __init__(
        self,
        model_name: str = "gemini-pro",
        gemini_api_key: str = None,
        *args,
        **kwargs,
    ):
```

| Parameter        | Type | Description                                                              | Default Value |
|------------------|------|--------------------------------------------------------------------------|---------------|
| `model_name`     | str  | The name of the Gemini model.                                            | "gemini-pro"  |
| `gemini_api_key` | str  | The Gemini API key. If not provided, it is fetched from the environment. | None          |

- `model_name`: Specifies the name of the Gemini model to use. By default, it is set to "gemini-pro", but you can specify a different model if needed.

- `gemini_api_key`: Allows you to provide your Gemini API key directly. If not provided, the constructor fetches it from the environment using the `get_gemini_api_key_env()` helper function.
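
For example, a minimal sketch of supplying the key directly (placeholder value shown):

```python
# Supplying the key explicitly instead of relying on the environment
gemini = Gemini(gemini_api_key="your-api-key")
```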
##### Methods

1. **run()**

```python
def run(
    self,
    task: str = None,
    img: str = None,
    *args,
    **kwargs,
) -> str:
```

| Parameter  | Type     | Description                              |
|------------|----------|------------------------------------------|
| `task`     | str      | The textual task for content generation. |
| `img`      | str      | The path to the image to be processed.   |
| `*args`    | Variable | Additional positional arguments.         |
| `**kwargs` | Variable | Additional keyword arguments.            |

- `task`: Specifies the textual task for content generation. It can be a sentence or a phrase that describes the desired content.

- `img`: Provides the path to the image that will be processed along with the textual task. Gemini combines the visual information from the image with the textual task to generate content.

- `*args` and `**kwargs`: Allow additional, flexible arguments to be passed to the underlying Gemini model. These arguments can vary based on the specific Gemini model being used.

**Returns**: A string containing the generated content.

**Examples**:

```python
from swarms.models import Gemini

# Initialize the Gemini model
gemini = Gemini()

# Generate content for a textual task with an image
generated_content = gemini.run(
    task="Describe this image",
    img="image.jpg",
)

# Print the generated content
print(generated_content)
```

In this example, we initialize the Gemini model, provide a textual task, and specify an image for processing. The `run()` method generates content based on the input and returns the result.

2. **process_img()**

```python
def process_img(
    self,
    img: str = None,
    type: str = "image/png",
    *args,
    **kwargs,
):
```

| Parameter  | Type     | Description                                     | Default Value |
|------------|----------|-------------------------------------------------|---------------|
| `img`      | str      | The path to the image to be processed.          | None          |
| `type`     | str      | The MIME type of the image (e.g., "image/png"). | "image/png"   |
| `*args`    | Variable | Additional positional arguments.                | -             |
| `**kwargs` | Variable | Additional keyword arguments.                   | -             |

- `img`: Specifies the path to the image that will be processed. It's essential to provide a valid image path for image-based content generation.

- `type`: Indicates the MIME type of the image. By default, it is set to "image/png", but you should change it to match the image format you're using.

- `*args` and `**kwargs`: Allow additional, flexible arguments to be passed to the underlying Gemini model. These arguments can vary based on the specific Gemini model being used.

**Raises**: `ValueError` if any of the following conditions are met:
- No image is provided.
- The image type is not specified.
- The Gemini API key is missing.

**Examples**:

```python
from swarms.models.gemini import Gemini

# Initialize the Gemini model
gemini = Gemini()

# Process an image into the payload format the model expects
processed_image = gemini.process_img(
    img="image.jpg",
    type="image/jpeg",
)

# `run()` calls `process_img()` internally, so pass the image path directly
generated_content = gemini.run(
    task="Describe this image",
    img="image.jpg",
)

# Print the generated content
print(generated_content)
```

In this example, we process an image with the `process_img()` method; `run()` performs the same processing internally, so content generation only needs the image path.
#### Additional Information

- Gemini is designed to work seamlessly with various multimodal AI models, making it a powerful tool for content generation tasks.

- The module uses the `google.generativeai` package to access the underlying AI models; the sketch after this list shows the raw calls it wraps. Ensure that you have this package installed to leverage the full capabilities of Gemini.

- It's essential to provide a valid Gemini API key for authentication. You can either pass it directly during initialization or store it in the environment variable `GEMINI_API_KEY`.

- Gemini's flexibility allows you to experiment with different Gemini models and tailor the content generation process to your specific needs (also illustrated in the sketch below).

- Keep in mind that Gemini handles both textual and image inputs, making it a valuable asset for various applications, including natural language processing and computer vision tasks.

- If you encounter any issues or have specific requirements, refer to the Gemini documentation for more details and advanced usage.
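
A minimal sketch of the underlying `google.generativeai` usage and of model selection through the wrapper. The raw calls (`genai.configure`, `genai.GenerativeModel`, `generate_content`) come from the `google-generativeai` package; the `"gemini-pro-vision"` model name is an assumption about what your API key can access.

```python
import os

import google.generativeai as genai

from swarms.models import Gemini

# Configure the client with the key from the environment
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

# The raw calls the Gemini wrapper makes under the hood
raw_model = genai.GenerativeModel("gemini-pro")
print(raw_model.generate_content("Write a haiku about swarms").text)

# Selecting a different model through the wrapper
# ("gemini-pro-vision" availability is an assumption)
gemini_vision = Gemini(model_name="gemini-pro-vision")
print(gemini_vision.run(task="Describe this image", img="image.jpg"))
```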
### References and Resources

- [Gemini GitHub Repository](https://github.com/swarms/gemini): Explore the Gemini repository for additional information, updates, and examples.

- [Google GenerativeAI Documentation](https://docs.google.com/document/d/1WZSBw6GsOhOCYm0ArydD_9uy6nPPA1KFIbKPhjj43hA): Dive deeper into the capabilities of the Google GenerativeAI package used by Gemini.

- [Gemini API Documentation](https://gemini-api-docs.example.com): Access the official documentation for the Gemini API to explore advanced features and integrations.

## Conclusion

In this documentation, we've explored the Gemini module: its purpose, architecture, methods, and usage examples. Gemini empowers developers to generate content by combining textual tasks and images, making it a valuable asset for multimodal AI applications. Whether you're working on natural language processing or computer vision projects, Gemini can help you achieve impressive results.

(binary image added: 74 KiB)
@ -0,0 +1,35 @@
import os

from dotenv import load_dotenv

from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.visual_cot import VISUAL_CHAIN_OF_THOUGHT
from swarms.structs import Agent

# Load the environment variables
load_dotenv()

# Get the API key from the environment
api_key = os.environ.get("OPENAI_API_KEY")

# Initialize the language model
llm = GPT4VisionAPI(
    openai_api_key=api_key,
    max_tokens=500,
)

# Initialize the task
task = "This is an eye test. What do you see?"
img = "playground/demos/multi_modal_chain_of_thought/eyetest.jpg"

# Initialize the workflow
agent = Agent(
    llm=llm,
    max_loops=2,
    autosave=True,
    sop=VISUAL_CHAIN_OF_THOUGHT,
)

# Run the workflow on a task
out = agent.run(task=task, img=img)
print(out)
@ -0,0 +1,160 @@
import os
import subprocess as sp
from pathlib import Path

from dotenv import load_dotenv

from swarms.models.base_multimodal_model import BaseMultiModalModel

try:
    import google.generativeai as genai
except ImportError as error:
    print(f"Error importing google.generativeai: {error}")
    print("Please install the google.generativeai package")
    print("pip install google-generativeai")
    sp.run(["pip", "install", "--upgrade", "google-generativeai"])


load_dotenv()


# Helpers
def get_gemini_api_key_env():
    """Get the Gemini API key from the environment.

    Raises:
        ValueError: if GEMINI_API_KEY is not set.

    Returns:
        str: the API key.
    """
    key = os.getenv("GEMINI_API_KEY")
    if key is None:
        raise ValueError("Please provide a Gemini API key")
    return key


# Main class
class Gemini(BaseMultiModalModel):
    """Gemini model

    Args:
        BaseMultiModalModel (class): Base multimodal model class
        model_name (str, optional): model name. Defaults to "gemini-pro".
        gemini_api_key (str, optional): Gemini API key. Defaults to None,
            in which case it is read from the GEMINI_API_KEY environment
            variable.

    Methods:
        run: run the Gemini model
        process_img: process the image
        chat: chat with the Gemini model


    Examples:
        >>> from swarms.models import Gemini
        >>> gemini = Gemini()
        >>> gemini.run(
        >>>     task="A dog",
        >>>     img="dog.png",
        >>> )
    """

    def __init__(
        self,
        model_name: str = "gemini-pro",
        gemini_api_key: str = None,
        *args,
        **kwargs,
    ):
        if model_name is None:
            raise ValueError("Please provide a model name")
        super().__init__(model_name, *args, **kwargs)
        self.model_name = model_name
        # Fall back to the environment if no key was passed in
        self.gemini_api_key = gemini_api_key or get_gemini_api_key_env()

        # Configure the client, then initialize the model
        genai.configure(api_key=self.gemini_api_key)
        self.model = genai.GenerativeModel(
            model_name, *args, **kwargs
        )

    def run(
        self,
        task: str = None,
        img: str = None,
        *args,
        **kwargs,
    ) -> str:
        """Run the Gemini model

        Args:
            task (str, optional): textual task. Defaults to None.
            img (str, optional): path to an image. Defaults to None.

        Returns:
            str: output from the model
        """
        try:
            if img:
                processed_img = self.process_img(img, *args, **kwargs)
                response = self.model.generate_content(
                    [task, processed_img], *args, **kwargs
                )
            else:
                response = self.model.generate_content(
                    task, *args, **kwargs
                )
            return response.text
        except Exception as error:
            print(f"Error running Gemini model: {error}")

    def process_img(
        self,
        img: str = None,
        type: str = "image/png",
        *args,
        **kwargs,
    ):
        """Process the image into the inline-data payload the model expects.

        Args:
            img (str, optional): path to the image. Defaults to None.
            type (str, optional): MIME type of the image. Defaults to
                "image/png".

        Raises:
            ValueError: if no image is provided.
            ValueError: if no image type is provided.
            ValueError: if the Gemini API key is missing.
        """
        if img is None:
            raise ValueError("Please provide an image to process")
        if type is None:
            raise ValueError("Please provide the image type")
        if self.gemini_api_key is None:
            raise ValueError("Please provide a Gemini API key")

        try:
            # Load the image as an inline data part
            img = [
                {"mime_type": type, "data": Path(img).read_bytes()}
            ]
            return img
        except Exception as error:
            print(f"Error processing image: {error}")

    def chat(
        self,
        task: str = None,
        img: str = None,
        *args,
        **kwargs,
    ) -> str:
        """Chat with the Gemini model.

        Args:
            task (str, optional): textual task. Defaults to None.
            img (str, optional): path to an image. Defaults to None.

        Returns:
            str: the model's reply to the image message.
        """
        chat = self.model.start_chat()
        response = chat.send_message(task, *args, **kwargs)
        response1 = response.text
        print(response1)
        response = chat.send_message(img, *args, **kwargs)
        return response.text
@ -0,0 +1,140 @@
import logging
import os
from typing import Optional

import requests
from dotenv import load_dotenv

from swarms.models.base_llm import AbstractLLM

# Load environment variables
load_dotenv()


def together_api_key_env():
    """Get the API key from the environment."""
    return os.getenv("TOGETHER_API_KEY")


class TogetherModel(AbstractLLM):
    """
    Together Inference API

    This class is a wrapper for the Together AI API. It is used to run
    inference against models hosted on Together.

    Parameters
    ----------
    together_api_key : str
        The Together API key. Defaults to the TOGETHER_API_KEY environment variable.
    max_tokens : int
        The maximum number of tokens to generate. Defaults to 300.


    Methods
    -------
    run(task: str)
        Run the model.
    __call__(task: str)
        Run the model.

    Examples:
    ---------
    >>> from swarms.models.together import TogetherModel
    >>> llm = TogetherModel()
    >>> task = "What is the color of the object?"
    >>> llm.run(task)
    """

    def __init__(
        self,
        together_api_key: str = None,
        model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        logging_enabled: bool = False,
        max_workers: int = 10,
        max_tokens: int = 300,
        api_endpoint: str = "https://api.together.xyz",
        beautify: bool = False,
        streaming_enabled: Optional[bool] = False,
        meta_prompt: Optional[bool] = False,
        system_prompt: Optional[str] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Fall back to the environment if no key was passed in
        self.together_api_key = together_api_key or together_api_key_env()
        self.logging_enabled = logging_enabled
        self.model_name = model_name
        self.max_workers = max_workers
        self.max_tokens = max_tokens
        self.api_endpoint = api_endpoint
        self.beautify = beautify
        self.streaming_enabled = streaming_enabled
        self.meta_prompt = meta_prompt
        self.system_prompt = system_prompt

        if self.logging_enabled:
            logging.basicConfig(level=logging.DEBUG)
        else:
            # Disable debug logs for requests and urllib3
            logging.getLogger("requests").setLevel(logging.WARNING)
            logging.getLogger("urllib3").setLevel(logging.WARNING)

        if self.meta_prompt:
            self.system_prompt = self.meta_prompt_init()

    # Send a chat-completion request to the Together API
    def run(self, task: str = None, *args, **kwargs):
        """Run the model."""
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.together_api_key}",
            }
            payload = {
                "model": self.model_name,
                "messages": [
                    {
                        "role": "system",
                        "content": self.system_prompt or "",
                    },
                    {
                        "role": "user",
                        "content": task,
                    },
                ],
                "max_tokens": self.max_tokens,
                **kwargs,
            }
            response = requests.post(
                self.api_endpoint,
                headers=headers,
                json=payload,
            )

            out = response.json()
            if "choices" in out and out["choices"]:
                content = (
                    out["choices"][0]
                    .get("message", {})
                    .get("content", None)
                )
                if self.streaming_enabled:
                    content = self.stream_response(content)
                return content
            else:
                print("No valid response in 'choices'")
                return None

        except Exception as error:
            print(
                f"Error with the request: {error}, make sure you"
                " double check input types and positions"
            )
            return None
@ -0,0 +1,58 @@
def react_prompt(task: str = None):
    PROMPT = f"""
    Task Description:
    Accomplish the following task using the reasoning guidelines below: {task}


    ######### REASONING GUIDELINES #########
    You're an autonomous agent that has been tasked with {task}. You have been given a set of guidelines to follow to accomplish this task. You must follow the guidelines exactly.

    Step 1: Observation

    Begin by carefully observing the situation or problem at hand. Describe what you see, identify key elements, and note any relevant details.

    Use <observation>...</observation> tokens to encapsulate your observations.

    Example:
    <observation> [Describe your initial observations of the task or problem here.] </observation>

    Step 2: Thought Process

    Analyze the observations. Consider different angles, potential challenges, and any underlying patterns or connections.

    Think about possible solutions or approaches to address the task.

    Use <thought>...</thought> tokens to encapsulate your thinking process.

    Example:
    <thought> [Explain your analysis of the observations, your reasoning behind potential solutions, and any assumptions or considerations you are making.] </thought>

    Step 3: Action Planning

    Based on your thoughts and analysis, plan a series of actions to solve the problem or complete the task.

    Detail the steps you intend to take, resources you will use, and how these actions will address the key elements identified in your observations.

    Use <action>...</action> tokens to encapsulate your action plan.

    Example:
    <action> [List the specific actions you plan to take, including any steps to gather more information or implement a solution.] </action>

    Step 4: Execute and Reflect

    Implement your action plan. As you proceed, continue to observe and think, adjusting your actions as needed.

    Reflect on the effectiveness of your actions and the outcome. Consider what worked well and what could be improved.

    Use <observation>...</observation>, <thought>...</thought>, and <action>...</action> tokens as needed to describe this ongoing process.

    Example:
    <observation> [New observations during action implementation.] </observation>
    <thought> [Thoughts on how the actions are affecting the situation, adjustments needed, etc.] </thought>
    <action> [Adjusted or continued actions to complete the task.] </action>

    Guidance:
    Remember, your goal is to provide a transparent and logical process that leads from observation to effective action. Your responses should demonstrate clear thinking, an understanding of the problem, and a rational approach to solving it. The use of tokens helps to structure your response and clarify the different stages of your reasoning and action.

    """
    return PROMPT
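

# --- Usage sketch (illustrative, not part of the original file) ---
# The prompt is plain text, so it can serve as an agent SOP, mirroring
# how VISUAL_CHAIN_OF_THOUGHT is wired elsewhere in this commit. `llm`
# stands for any swarms model instance (an assumption, not shown here).
#
# from swarms.structs import Agent
#
# prompt = react_prompt(task="Plan a three-step data-cleaning pipeline")
# agent = Agent(llm=llm, max_loops=2, sop=prompt)
# print(agent.run(task="Plan a three-step data-cleaning pipeline"))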
@ -1,5 +1,433 @@
import asyncio
import concurrent.futures
import functools
import json
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import psutil

try:
    import gzip
except ImportError as error:
    print(f"Error importing gzip: {error}")


class BaseStructure:
    """Base structure for all swarm structures.

    Attributes:
        name (Optional[str]): name of the structure.
        description (Optional[str]): description of the structure.
        save_metadata_on (bool): whether to save metadata to disk.
        save_artifact_path (Optional[str]): directory for saved artifacts.
        save_metadata_path (Optional[str]): directory for saved metadata.
        save_error_path (Optional[str]): directory for error logs.

    Methods:
        run: run the structure (overridden by subclasses).
        save_to_file: save data to a JSON file.
        load_from_file: load data from a JSON file.
        save_metadata: save metadata to file.
        load_metadata: load metadata from file.
        log_error: log an error to file.
        save_artifact: save an artifact to file.
        load_artifact: load an artifact from file.
        log_event: log an event to file.
        run_async: run a callable asynchronously.
        save_metadata_async: save metadata asynchronously.
        load_metadata_async: load metadata asynchronously.
        log_error_async: log an error asynchronously.
        save_artifact_async: save an artifact asynchronously.
        load_artifact_async: load an artifact asynchronously.
        log_event_async: log an event asynchronously.
        asave_to_file: save data to file asynchronously.
        aload_from_file: load data from file asynchronously.
        run_in_thread: run a callable in a thread.
        save_metadata_in_thread: save metadata in a thread.
        run_concurrent: run the structure concurrently.
        compress_data: gzip-compress data.
        decompress_data: gzip-decompress data.
        run_batched: run over batched data.
        load_config: load a JSON config file.
        backup_data: back up data to a timestamped file.
        monitor_resources: log current resource usage.
        run_with_resources: run with resource monitoring.
        run_with_resources_batched: run batched with resource monitoring.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        save_metadata: bool = True,
        save_artifact_path: Optional[str] = "./artifacts",
        save_metadata_path: Optional[str] = "./metadata",
        save_error_path: Optional[str] = "./errors",
        *args,
        **kwargs,
    ):
        self.name = name
        self.description = description
        # Stored under a distinct name so it does not shadow the
        # save_metadata() method below
        self.save_metadata_on = save_metadata
        self.save_artifact_path = save_artifact_path
        self.save_metadata_path = save_metadata_path
        self.save_error_path = save_error_path

    def run(self, *args, **kwargs):
        """Run the structure. Subclasses are expected to override this."""
        pass

    def save_to_file(self, data: Any, file_path: str):
        """Save data to a JSON file.

        Args:
            data (Any): data to save.
            file_path (str): destination path.
        """
        # Create the parent directory if it does not exist yet
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(file_path, "w") as file:
            json.dump(data, file)

    def load_from_file(self, file_path: str) -> Any:
        """Load data from a JSON file.

        Args:
            file_path (str): path to load from.

        Returns:
            Any: the deserialized data.
        """
        with open(file_path, "r") as file:
            return json.load(file)

    def save_metadata(self, metadata: Dict[str, Any]):
        """Save metadata to file.

        Args:
            metadata (Dict[str, Any]): metadata to save.
        """
        if self.save_metadata_on:
            file_path = os.path.join(
                self.save_metadata_path, f"{self.name}_metadata.json"
            )
            self.save_to_file(metadata, file_path)

    def load_metadata(self) -> Dict[str, Any]:
        """Load metadata from file.

        Returns:
            Dict[str, Any]: the loaded metadata.
        """
        file_path = os.path.join(
            self.save_metadata_path, f"{self.name}_metadata.json"
        )
        return self.load_from_file(file_path)

    def log_error(self, error_message: str):
        """Log an error to file.

        Args:
            error_message (str): message to append to the error log.
        """
        os.makedirs(self.save_error_path, exist_ok=True)
        file_path = os.path.join(
            self.save_error_path, f"{self.name}_errors.log"
        )
        with open(file_path, "a") as file:
            file.write(f"{error_message}\n")

    def save_artifact(self, artifact: Any, artifact_name: str):
        """Save an artifact to file.

        Args:
            artifact (Any): artifact to save.
            artifact_name (str): file name (without extension).
        """
        file_path = os.path.join(
            self.save_artifact_path, f"{artifact_name}.json"
        )
        self.save_to_file(artifact, file_path)

    def load_artifact(self, artifact_name: str) -> Any:
        """Load an artifact from file.

        Args:
            artifact_name (str): file name (without extension).

        Returns:
            Any: the loaded artifact.
        """
        file_path = os.path.join(
            self.save_artifact_path, f"{artifact_name}.json"
        )
        return self.load_from_file(file_path)

    def _current_timestamp(self):
        """Current timestamp.

        Returns:
            str: the current time, formatted as "%Y-%m-%d %H:%M:%S".
        """
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def log_event(
        self,
        event: str,
        event_type: str = "INFO",
    ):
        """Log an event to file.

        Args:
            event (str): event description.
            event_type (str, optional): severity label. Defaults to "INFO".
        """
        timestamp = self._current_timestamp()
        log_message = f"[{timestamp}] [{event_type}] {event}\n"
        os.makedirs(self.save_metadata_path, exist_ok=True)
        file_path = os.path.join(
            self.save_metadata_path, f"{self.name}_events.log"
        )
        with open(file_path, "a") as file:
            file.write(log_message)

    async def run_async(
        self, func: Optional[Callable] = None, *args, **kwargs
    ):
        """Run a callable (default: self.run) asynchronously."""
        func = func or self.run
        if asyncio.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        loop = asyncio.get_event_loop()
        # run_in_executor does not forward keyword arguments,
        # so bind them with functools.partial first
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs)
        )

    async def save_metadata_async(self, metadata: Dict[str, Any]):
        """Save metadata to file asynchronously.

        Args:
            metadata (Dict[str, Any]): metadata to save.
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, self.save_metadata, metadata
        )

    async def load_metadata_async(self) -> Dict[str, Any]:
        """Load metadata from file asynchronously.

        Returns:
            Dict[str, Any]: the loaded metadata.
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.load_metadata)

    async def log_error_async(self, error_message: str):
        """Log an error to file asynchronously.

        Args:
            error_message (str): message to append to the error log.
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, self.log_error, error_message
        )

    async def save_artifact_async(
        self, artifact: Any, artifact_name: str
    ):
        """Save an artifact to file asynchronously.

        Args:
            artifact (Any): artifact to save.
            artifact_name (str): file name (without extension).
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, self.save_artifact, artifact, artifact_name
        )

    async def load_artifact_async(self, artifact_name: str) -> Any:
        """Load an artifact from file asynchronously.

        Args:
            artifact_name (str): file name (without extension).

        Returns:
            Any: the loaded artifact.
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, self.load_artifact, artifact_name
        )

    async def log_event_async(
        self,
        event: str,
        event_type: str = "INFO",
    ):
        """Log an event to file asynchronously.

        Args:
            event (str): event description.
            event_type (str, optional): severity label. Defaults to "INFO".
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, self.log_event, event, event_type
        )

    async def asave_to_file(
        self, data: Any, file: str, *args, **kwargs
    ):
        """Save data to file asynchronously.

        Args:
            data (Any): data to save.
            file (str): destination path.
        """
        await asyncio.to_thread(
            self.save_to_file,
            data,
            file,
            *args,
        )

    async def aload_from_file(
        self,
        file: str,
    ) -> Any:
        """Async load data from file.

        Args:
            file (str): path to load from.

        Returns:
            Any: the deserialized data.
        """
        return await asyncio.to_thread(self.load_from_file, file)

    def run_in_thread(
        self, func: Optional[Callable] = None, *args, **kwargs
    ):
        """Run a callable (default: self.run) in a thread."""
        func = func or self.run
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(func, *args, **kwargs)

    def save_metadata_in_thread(self, metadata: Dict[str, Any]):
        """Save metadata to file in a thread.

        Args:
            metadata (Dict[str, Any]): metadata to save.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(self.save_metadata, metadata)

    def run_concurrent(self, *args, **kwargs):
        """Run the structure concurrently."""
        return asyncio.run(self.run_async(*args, **kwargs))

    def compress_data(
        self,
        data: Any,
    ) -> bytes:
        """Compress data with gzip.

        Args:
            data (Any): JSON-serializable data.

        Returns:
            bytes: gzip-compressed JSON.
        """
        return gzip.compress(json.dumps(data).encode())

    def decompress_data(self, data: bytes) -> Any:
        """Decompress gzip-compressed JSON data.

        Args:
            data (bytes): gzip-compressed JSON.

        Returns:
            Any: the deserialized data.
        """
        return json.loads(gzip.decompress(data).decode())

    def run_batched(
        self,
        batched_data: List[Any],
        batch_size: int = 10,
        func: Optional[Callable] = None,
        *args,
        **kwargs,
    ):
        """Run over batched data.

        Args:
            batched_data (List[Any]): items to process.
            batch_size (int, optional): worker pool size. Defaults to 10.
            func (Callable, optional): callable applied to each item.
                Defaults to self.run.

        Returns:
            List[Any]: one result per input item.
        """
        func = func or self.run
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [
                executor.submit(func, data)
                for data in batched_data
            ]
            return [future.result() for future in futures]

    def load_config(
        self, config: str = None, *args, **kwargs
    ) -> Dict[str, Any]:
        """Load a JSON config from file.

        Args:
            config (str, optional): path to the config file. Defaults to None.

        Returns:
            Dict[str, Any]: the loaded config.
        """
        return self.load_from_file(config)

    def backup_data(
        self, data: Any, backup_path: str = None, *args, **kwargs
    ):
        """Back up data to a timestamped file.

        Args:
            data (Any): data to back up.
            backup_path (str, optional): destination directory. Defaults to None.
        """
        timestamp = self._current_timestamp()
        backup_file_path = f"{backup_path}/{timestamp}.json"
        self.save_to_file(data, backup_file_path)

    def monitor_resources(self):
        """Log current memory and CPU usage."""
        memory = psutil.virtual_memory().percent
        cpu_usage = psutil.cpu_percent(interval=1)
        self.log_event(
            f"Resource usage - Memory: {memory}%, CPU: {cpu_usage}%"
        )

    def run_with_resources(
        self, func: Optional[Callable] = None, *args, **kwargs
    ):
        """Run a callable (default: self.run) with resource monitoring."""
        self.monitor_resources()
        func = func or self.run
        return func(*args, **kwargs)

    def run_with_resources_batched(
        self,
        batched_data: List[Any],
        batch_size: int = 10,
        func: Optional[Callable] = None,
        *args,
        **kwargs,
    ):
        """Run batched data with resource monitoring.

        Args:
            batched_data (List[Any]): items to process.
            batch_size (int, optional): worker pool size. Defaults to 10.
            func (Callable, optional): callable applied to each item.
                Defaults to self.run.

        Returns:
            List[Any]: one result per input item.
        """
        self.monitor_resources()
        return self.run_batched(
            batched_data, batch_size, func, *args, **kwargs
        )
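

# --- Usage sketch (illustrative, not part of the original file) ---
# BaseStructure leaves `run` to subclasses; the helpers then operate on it.
#
# class EchoStructure(BaseStructure):
#     def run(self, data):
#         return f"echo: {data}"
#
# s = EchoStructure(name="echo")
# print(s.run_batched(list(range(4)), batch_size=2))
# # -> ['echo: 0', 'echo: 1', 'echo: 2', 'echo: 3']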
@ -0,0 +1,218 @@
import pytest
from unittest.mock import patch, Mock

from swarms.models.gemini import Gemini


# Define test fixtures
@pytest.fixture
def mock_gemini_api_key(monkeypatch):
    monkeypatch.setenv("GEMINI_API_KEY", "mocked-api-key")


@pytest.fixture
def mock_genai_model():
    return Mock()


# Test initialization of Gemini
def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
    model = Gemini()
    assert model.model_name == "gemini-pro"
    assert model.gemini_api_key == "mocked-api-key"


def test_gemini_init_custom_params(
    mock_gemini_api_key, mock_genai_model
):
    model = Gemini(
        model_name="custom-model", gemini_api_key="custom-api-key"
    )
    assert model.model_name == "custom-model"
    assert model.gemini_api_key == "custom-api-key"


# Test Gemini run method
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_with_img(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    model = Gemini()
    task = "A cat"
    img = "cat.png"
    response_mock = Mock(text="Generated response")
    mock_generate_content.return_value = response_mock
    mock_process_img.return_value = "Processed image"

    response = model.run(task=task, img=img)

    assert response == "Generated response"
    mock_generate_content.assert_called_with(
        [task, "Processed image"]
    )
    mock_process_img.assert_called_with(img)


@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_without_img(
    mock_generate_content, mock_gemini_api_key, mock_genai_model
):
    model = Gemini()
    task = "A cat"
    response_mock = Mock(text="Generated response")
    mock_generate_content.return_value = response_mock

    response = model.run(task=task)

    assert response == "Generated response"
    mock_generate_content.assert_called_with(task)


@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_exception(
    mock_generate_content, mock_gemini_api_key, mock_genai_model
):
    model = Gemini()
    task = "A cat"
    mock_generate_content.side_effect = Exception("Test exception")

    response = model.run(task=task)

    assert response is None


# Test Gemini process_img method
def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
    model = Gemini(gemini_api_key="custom-api-key")
    img = "cat.png"
    img_data = b"Mocked image data"

    with patch(
        "swarms.models.gemini.Path.read_bytes"
    ) as read_bytes_mock:
        read_bytes_mock.return_value = img_data

        processed_img = model.process_img(img)

    assert processed_img == [
        {"mime_type": "image/png", "data": img_data}
    ]


# Test Gemini initialization with missing API key
def test_gemini_init_missing_api_key(monkeypatch):
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    with pytest.raises(
        ValueError, match="Please provide a Gemini API key"
    ):
        Gemini(gemini_api_key=None)


# Test Gemini initialization with missing model name
def test_gemini_init_missing_model_name(mock_gemini_api_key):
    with pytest.raises(
        ValueError, match="Please provide a model name"
    ):
        Gemini(model_name=None)


# Test Gemini run method with empty task
def test_gemini_run_empty_task(mock_gemini_api_key, mock_genai_model):
    model = Gemini()
    task = ""
    response = model.run(task=task)
    assert response is None


# Test Gemini run method with empty image
def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
    model = Gemini()
    task = "A cat"
    img = ""
    response = model.run(task=task, img=img)
    assert response is None


# Test Gemini process_img method with missing image
def test_gemini_process_img_missing_image(
    mock_gemini_api_key, mock_genai_model
):
    model = Gemini()
    img = None
    with pytest.raises(
        ValueError, match="Please provide an image to process"
    ):
        model.process_img(img=img)


# Test Gemini process_img method with missing image type
def test_gemini_process_img_missing_image_type(
    mock_gemini_api_key, mock_genai_model
):
    model = Gemini()
    img = "cat.png"
    with pytest.raises(
        ValueError, match="Please provide the image type"
    ):
        model.process_img(img=img, type=None)


# Test Gemini process_img method with missing Gemini API key
def test_gemini_process_img_missing_api_key(
    mock_gemini_api_key, mock_genai_model
):
    model = Gemini()
    model.gemini_api_key = None
    img = "cat.png"
    with pytest.raises(
        ValueError, match="Please provide a Gemini API key"
    ):
        model.process_img(img=img, type="image/png")


# Test Gemini run method with mocked image processing
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
@patch("swarms.models.gemini.Gemini.process_img")
def test_gemini_run_mock_img_processing(
    mock_process_img,
    mock_generate_content,
    mock_gemini_api_key,
    mock_genai_model,
):
    model = Gemini()
    task = "A cat"
    img = "cat.png"
    response_mock = Mock(text="Generated response")
    mock_generate_content.return_value = response_mock
    mock_process_img.return_value = "Processed image"

    response = model.run(task=task, img=img)

    assert response == "Generated response"
    mock_generate_content.assert_called_with(
        [task, "Processed image"]
    )
    mock_process_img.assert_called_with(img)


# Test Gemini run method with mocked image processing and exception
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_mock_img_processing_exception(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    model = Gemini()
    task = "A cat"
    img = "cat.png"
    mock_process_img.side_effect = Exception("Test exception")

    response = model.run(task=task, img=img)

    assert response is None
    mock_generate_content.assert_not_called()
    mock_process_img.assert_called_with(img)
@ -0,0 +1,144 @@
import logging
from unittest.mock import patch, Mock

import pytest
import requests

from swarms.models.together import TogetherModel


@pytest.fixture
def mock_api_key(monkeypatch):
    monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")


def test_init_defaults(mock_api_key):
    model = TogetherModel()
    assert model.together_api_key == "mocked-api-key"
    assert model.logging_enabled is False
    assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
    assert model.max_workers == 10
    assert model.max_tokens == 300
    assert model.api_endpoint == "https://api.together.xyz"
    assert model.beautify is False
    assert model.streaming_enabled is False
    assert model.meta_prompt is False
    assert model.system_prompt is None


def test_init_custom_params(mock_api_key):
    model = TogetherModel(
        together_api_key="custom-api-key",
        logging_enabled=True,
        model_name="custom-model",
        max_workers=5,
        max_tokens=500,
        api_endpoint="https://custom-api.together.xyz",
        beautify=True,
        streaming_enabled=True,
        system_prompt="system-prompt",
    )
    assert model.together_api_key == "custom-api-key"
    assert model.logging_enabled is True
    assert model.model_name == "custom-model"
    assert model.max_workers == 5
    assert model.max_tokens == 500
    assert model.api_endpoint == "https://custom-api.together.xyz"
    assert model.beautify is True
    assert model.streaming_enabled is True
    assert model.system_prompt == "system-prompt"


@patch("swarms.models.together.requests.post")
def test_run_success(mock_post, mock_api_key):
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response == "Generated response"


@patch("swarms.models.together.requests.post")
def test_run_failure(mock_post, mock_api_key):
    mock_post.side_effect = requests.exceptions.RequestException(
        "Request failed"
    )

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None


@patch("swarms.models.together.requests.post")
def test_run_with_logging_enabled(mock_post, caplog, mock_api_key):
    mock_post.return_value = Mock(
        json=Mock(return_value={"choices": []})
    )
    model = TogetherModel(logging_enabled=True)
    task = "What is the color of the object?"

    with caplog.at_level(logging.DEBUG):
        model.run(task)

    mock_post.assert_called_once()


@pytest.mark.parametrize(
    "invalid_input", [None, 123, ["list", "of", "items"]]
)
@patch("swarms.models.together.requests.post")
def test_invalid_task_input(mock_post, invalid_input, mock_api_key):
    mock_post.return_value = Mock(json=Mock(return_value={}))
    model = TogetherModel()
    response = model.run(invalid_input)

    assert response is None


@patch("swarms.models.together.requests.post")
def test_run_streaming_enabled(mock_post, mock_api_key):
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response

    model = TogetherModel(streaming_enabled=True)
    task = "What is the color of the object?"
    response = model.run(task)

    assert response == "Generated response"


@patch("swarms.models.together.requests.post")
def test_run_empty_choices(mock_post, mock_api_key):
    mock_response = Mock()
    mock_response.json.return_value = {"choices": []}
    mock_post.return_value = mock_response

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None


@patch("swarms.models.together.requests.post")
def test_run_with_exception(mock_post, mock_api_key):
    mock_post.side_effect = Exception("Test exception")

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None


def test_init_logging_disabled(monkeypatch):
    monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
    model = TogetherModel()
    assert model.logging_enabled is False
    assert not model.system_prompt
@ -0,0 +1,287 @@
import os
from datetime import datetime

import pytest

from swarms.swarms.base import BaseStructure


class TestBaseStructure:
    def test_init(self):
        base_structure = BaseStructure(
            name="TestStructure",
            description="Test description",
            save_metadata=True,
            save_artifact_path="./test_artifacts",
            save_metadata_path="./test_metadata",
            save_error_path="./test_errors",
        )

        assert base_structure.name == "TestStructure"
        assert base_structure.description == "Test description"
        assert base_structure.save_metadata_on is True
        assert base_structure.save_artifact_path == "./test_artifacts"
        assert base_structure.save_metadata_path == "./test_metadata"
        assert base_structure.save_error_path == "./test_errors"

    def test_save_to_file_and_load_from_file(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        file_path = os.path.join(tmp_dir, "test_file.json")

        data_to_save = {"key": "value"}
        base_structure = BaseStructure()

        base_structure.save_to_file(data_to_save, file_path)
        loaded_data = base_structure.load_from_file(file_path)

        assert loaded_data == data_to_save

    def test_save_metadata_and_load_metadata(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(save_metadata_path=tmp_dir)

        metadata = {"name": "Test", "description": "Test metadata"}
        base_structure.save_metadata(metadata)
        loaded_metadata = base_structure.load_metadata()

        assert loaded_metadata == metadata

    def test_log_error(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(
            name="TestStructure", save_error_path=tmp_dir
        )

        error_message = "Test error message"
        base_structure.log_error(error_message)

        log_file = os.path.join(tmp_dir, "TestStructure_errors.log")
        with open(log_file, "r") as file:
            lines = file.readlines()
        assert len(lines) == 1
        assert lines[0] == f"{error_message}\n"

    def test_save_artifact_and_load_artifact(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(save_artifact_path=tmp_dir)

        artifact = {"key": "value"}
        artifact_name = "test_artifact"
        base_structure.save_artifact(artifact, artifact_name)
        loaded_artifact = base_structure.load_artifact(artifact_name)

        assert loaded_artifact == artifact

    def test_current_timestamp(self):
        base_structure = BaseStructure()
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        timestamp = base_structure._current_timestamp()
        assert timestamp == current_time

    def test_log_event(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(
            name="TestStructure", save_metadata_path=tmp_dir
        )

        event = "Test event"
        event_type = "INFO"
        base_structure.log_event(event, event_type)

        log_file = os.path.join(tmp_dir, "TestStructure_events.log")
        with open(log_file, "r") as file:
            lines = file.readlines()
        assert len(lines) == 1
        assert (
            lines[0]
            == f"[{base_structure._current_timestamp()}]"
            f" [{event_type}] {event}\n"
        )

    @pytest.mark.asyncio
    async def test_run_async(self):
        base_structure = BaseStructure()

        async def async_function():
            return "Async Test Result"

        result = await base_structure.run_async(async_function)
        assert result == "Async Test Result"

    @pytest.mark.asyncio
    async def test_save_metadata_async(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(save_metadata_path=tmp_dir)

        metadata = {"name": "Test", "description": "Test metadata"}
        await base_structure.save_metadata_async(metadata)
        loaded_metadata = base_structure.load_metadata()

        assert loaded_metadata == metadata

    @pytest.mark.asyncio
    async def test_log_error_async(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(
            name="TestStructure", save_error_path=tmp_dir
        )

        error_message = "Test error message"
        await base_structure.log_error_async(error_message)

        log_file = os.path.join(tmp_dir, "TestStructure_errors.log")
        with open(log_file, "r") as file:
            lines = file.readlines()
        assert len(lines) == 1
        assert lines[0] == f"{error_message}\n"

    @pytest.mark.asyncio
    async def test_save_artifact_async(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(save_artifact_path=tmp_dir)

        artifact = {"key": "value"}
        artifact_name = "test_artifact"
        await base_structure.save_artifact_async(
            artifact, artifact_name
        )
        loaded_artifact = base_structure.load_artifact(artifact_name)

        assert loaded_artifact == artifact

    @pytest.mark.asyncio
    async def test_load_artifact_async(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(save_artifact_path=tmp_dir)

        artifact = {"key": "value"}
        artifact_name = "test_artifact"
        base_structure.save_artifact(artifact, artifact_name)
        loaded_artifact = await base_structure.load_artifact_async(
            artifact_name
        )

        assert loaded_artifact == artifact

    @pytest.mark.asyncio
    async def test_log_event_async(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure(
            name="TestStructure", save_metadata_path=tmp_dir
        )

        event = "Test event"
        event_type = "INFO"
        await base_structure.log_event_async(event, event_type)

        log_file = os.path.join(tmp_dir, "TestStructure_events.log")
        with open(log_file, "r") as file:
            lines = file.readlines()
        assert len(lines) == 1
        assert (
            lines[0]
            == f"[{base_structure._current_timestamp()}]"
            f" [{event_type}] {event}\n"
        )

    @pytest.mark.asyncio
    async def test_asave_to_file(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        file_path = os.path.join(tmp_dir, "test_file.json")
        data_to_save = {"key": "value"}
        base_structure = BaseStructure()

        await base_structure.asave_to_file(data_to_save, file_path)
        loaded_data = base_structure.load_from_file(file_path)

        assert loaded_data == data_to_save

    @pytest.mark.asyncio
    async def test_aload_from_file(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        file_path = os.path.join(tmp_dir, "test_file.json")
        data_to_save = {"key": "value"}
        base_structure = BaseStructure()
        base_structure.save_to_file(data_to_save, file_path)

        loaded_data = await base_structure.aload_from_file(file_path)
        assert loaded_data == data_to_save

    def test_run_in_thread(self):
        base_structure = BaseStructure()
        result = base_structure.run_in_thread(
            lambda: "Thread Test Result"
        )
        assert result.result() == "Thread Test Result"

    def test_compress_and_decompress_data(self):
        base_structure = BaseStructure()
        data = {"key": "value"}
        compressed_data = base_structure.compress_data(data)
        decompressed_data = base_structure.decompress_data(
            compressed_data
        )
        assert decompressed_data == data

    def test_run_batched(self):
        base_structure = BaseStructure()

        def run_function(data):
            return f"Processed {data}"

        batched_data = list(range(10))
        result = base_structure.run_batched(
            batched_data, batch_size=5, func=run_function
        )

        expected_result = [
            f"Processed {data}" for data in batched_data
        ]
        assert result == expected_result

    def test_load_config(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        config_file = os.path.join(tmp_dir, "config.json")
        config_data = {"key": "value"}
        base_structure = BaseStructure()

        base_structure.save_to_file(config_data, config_file)
        loaded_config = base_structure.load_config(config_file)

        assert loaded_config == config_data

    def test_backup_data(self, tmpdir):
        tmp_dir = tmpdir.mkdir("test_dir")
        base_structure = BaseStructure()
        data_to_backup = {"key": "value"}
        base_structure.backup_data(
            data_to_backup, backup_path=tmp_dir
        )
        backup_files = os.listdir(tmp_dir)

        assert len(backup_files) == 1
        loaded_data = base_structure.load_from_file(
            os.path.join(tmp_dir, backup_files[0])
        )
        assert loaded_data == data_to_backup

    def test_monitor_resources(self):
        base_structure = BaseStructure()
        base_structure.monitor_resources()

    def test_run_with_resources(self):
        base_structure = BaseStructure()

        def run_function():
            base_structure.monitor_resources()
            return "Resource Test Result"

        result = base_structure.run_with_resources(run_function)
        assert result == "Resource Test Result"

    def test_run_with_resources_batched(self):
        base_structure = BaseStructure()

        def run_function(data):
            base_structure.monitor_resources()
            return f"Processed {data}"

        batched_data = list(range(10))
        result = base_structure.run_with_resources_batched(
            batched_data, batch_size=5, func=run_function
        )

        expected_result = [
            f"Processed {data}" for data in batched_data
        ]
        assert result == expected_result