From c5ba940e477f804e68234d3d7c241a14981ddf0f Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 26 Dec 2023 19:15:09 -0500 Subject: [PATCH] [ModelParallelizer] --- docs/swarms/swarms/godmode.md | 64 ++++---- mkdocs.yml | 2 +- playground/swarms/godmode.py | 4 +- scripts/auto_tests_docs/auto_docs.py | 101 ++++++++++++ scripts/auto_tests_docs/auto_tests.py | 122 ++++++++++++++ scripts/auto_tests_docs/docs.py | 199 +++++++++++++++++++++++ scripts/auto_tests_docs/update_mkdocs.py | 60 +++++++ scripts/code_quality.sh | 16 +- scripts/delete_pycache.sh | 4 - scripts/get_package_requirements.py | 9 +- scripts/requirementstxt_to_pyproject.py | 5 +- scripts/test_name.sh | 1 + swarms/agents/simple_agent.py | 3 +- swarms/models/open_dalle.py | 66 ++++++++ swarms/prompts/__init__.py | 3 +- swarms/prompts/documentation.py | 4 +- swarms/prompts/tests.py | 131 ++++++++------- swarms/swarms/__init__.py | 4 +- swarms/swarms/god_mode.py | 14 +- tests/models/test_open_dalle.py | 59 +++++++ 20 files changed, 737 insertions(+), 134 deletions(-) create mode 100644 scripts/auto_tests_docs/auto_docs.py create mode 100644 scripts/auto_tests_docs/auto_tests.py create mode 100644 scripts/auto_tests_docs/docs.py create mode 100644 scripts/auto_tests_docs/update_mkdocs.py delete mode 100644 scripts/delete_pycache.sh create mode 100644 swarms/models/open_dalle.py create mode 100644 tests/models/test_open_dalle.py diff --git a/docs/swarms/swarms/godmode.md b/docs/swarms/swarms/godmode.md index a0965c94..2d903a8d 100644 --- a/docs/swarms/swarms/godmode.md +++ b/docs/swarms/swarms/godmode.md @@ -1,4 +1,4 @@ -# `GodMode` Documentation +# `ModelParallelizer` Documentation ## Table of Contents 1. [Understanding the Purpose](#understanding-the-purpose) @@ -11,19 +11,19 @@ ## 1. Understanding the Purpose -To create comprehensive documentation for the `GodMode` class, let's begin by understanding its purpose and functionality. +To create comprehensive documentation for the `ModelParallelizer` class, let's begin by understanding its purpose and functionality. ### Purpose and Functionality -`GodMode` is a class designed to facilitate the orchestration of multiple Language Model Models (LLMs) to perform various tasks simultaneously. It serves as a powerful tool for managing, distributing, and collecting responses from these models. +`ModelParallelizer` is a class designed to facilitate the orchestration of multiple Language Model Models (LLMs) to perform various tasks simultaneously. It serves as a powerful tool for managing, distributing, and collecting responses from these models. Key features and functionality include: -- **Parallel Task Execution**: `GodMode` can distribute tasks to multiple LLMs and execute them in parallel, improving efficiency and reducing response time. +- **Parallel Task Execution**: `ModelParallelizer` can distribute tasks to multiple LLMs and execute them in parallel, improving efficiency and reducing response time. - **Structured Response Presentation**: The class presents the responses from LLMs in a structured tabular format, making it easy for users to compare and analyze the results. -- **Task History Tracking**: `GodMode` keeps a record of tasks that have been submitted, allowing users to review previous tasks and responses. +- **Task History Tracking**: `ModelParallelizer` keeps a record of tasks that have been submitted, allowing users to review previous tasks and responses. - **Asynchronous Execution**: The class provides options for asynchronous task execution, which can be particularly useful for handling a large number of tasks. @@ -33,29 +33,29 @@ Now that we have an understanding of its purpose, let's proceed to provide a det ### Overview -The `GodMode` class is a crucial component for managing and utilizing multiple LLMs in various natural language processing (NLP) tasks. Its architecture and functionality are designed to address the need for parallel processing and efficient response handling. +The `ModelParallelizer` class is a crucial component for managing and utilizing multiple LLMs in various natural language processing (NLP) tasks. Its architecture and functionality are designed to address the need for parallel processing and efficient response handling. ### Importance and Relevance -In the rapidly evolving field of NLP, it has become common to use multiple language models to achieve better results in tasks such as translation, summarization, and question answering. `GodMode` streamlines this process by allowing users to harness the capabilities of several LLMs simultaneously. +In the rapidly evolving field of NLP, it has become common to use multiple language models to achieve better results in tasks such as translation, summarization, and question answering. `ModelParallelizer` streamlines this process by allowing users to harness the capabilities of several LLMs simultaneously. Key points: -- **Parallel Processing**: `GodMode` leverages multithreading to execute tasks concurrently, significantly reducing the time required for processing. +- **Parallel Processing**: `ModelParallelizer` leverages multithreading to execute tasks concurrently, significantly reducing the time required for processing. - **Response Visualization**: The class presents responses in a structured tabular format, enabling users to visualize and analyze the outputs from different LLMs. -- **Task Tracking**: Developers can track the history of tasks submitted to `GodMode`, making it easier to manage and monitor ongoing work. +- **Task Tracking**: Developers can track the history of tasks submitted to `ModelParallelizer`, making it easier to manage and monitor ongoing work. ### Architecture and How It Works -The architecture and working of `GodMode` can be summarized in four steps: +The architecture and working of `ModelParallelizer` can be summarized in four steps: -1. **Task Reception**: `GodMode` receives a task from the user. +1. **Task Reception**: `ModelParallelizer` receives a task from the user. 2. **Task Distribution**: The class distributes the task to all registered LLMs. -3. **Response Collection**: `GodMode` collects the responses generated by the LLMs. +3. **Response Collection**: `ModelParallelizer` collects the responses generated by the LLMs. 4. **Response Presentation**: Finally, the class presents the responses from all LLMs in a structured tabular format, making it easy for users to compare and analyze the results. @@ -65,15 +65,15 @@ Now that we have an overview, let's proceed with a detailed class definition. ### Class Attributes -- `llms`: A list of LLMs (Language Model Models) that `GodMode` manages. +- `llms`: A list of LLMs (Language Model Models) that `ModelParallelizer` manages. - `last_responses`: Stores the responses from the most recent task. -- `task_history`: Keeps a record of all tasks submitted to `GodMode`. +- `task_history`: Keeps a record of all tasks submitted to `ModelParallelizer`. ### Methods -The `GodMode` class defines various methods to facilitate task distribution, execution, and response presentation. Let's examine some of the key methods: +The `ModelParallelizer` class defines various methods to facilitate task distribution, execution, and response presentation. Let's examine some of the key methods: - `run(task)`: Distributes a task to all LLMs, collects responses, and returns them. @@ -87,22 +87,22 @@ The `GodMode` class defines various methods to facilitate task distribution, exe - `save_responses_to_file(filename)`: Saves responses to a file for future reference. -- `load_llms_from_file(filename)`: Loads LLMs from a file, making it easy to configure `GodMode` for different tasks. +- `load_llms_from_file(filename)`: Loads LLMs from a file, making it easy to configure `ModelParallelizer` for different tasks. - `get_task_history()`: Retrieves the task history, allowing users to review previous tasks. - `summary()`: Provides a summary of task history and the last responses, aiding in post-processing and analysis. -Now that we have covered the class definition, let's delve into the functionality and usage of `GodMode`. +Now that we have covered the class definition, let's delve into the functionality and usage of `ModelParallelizer`. ## 4. Functionality and Usage ### Distributing a Task and Collecting Responses -One of the primary use cases of `GodMode` is to distribute a task to all registered LLMs and collect their responses. This can be achieved using the `run(task)` method. Below is an example: +One of the primary use cases of `ModelParallelizer` is to distribute a task to all registered LLMs and collect their responses. This can be achieved using the `run(task)` method. Below is an example: ```python -god_mode = GodMode(llms) +god_mode = ModelParallelizer(llms) responses = god_mode.run("Translate the following English text to French: 'Hello, how are you?'") ``` @@ -124,7 +124,7 @@ god_mode.save_responses_to_file("responses.txt") ### Task History -The `GodMode` class keeps track of the task history. Developers can access the task history using the `get_task_history()` method. Example: +The `ModelParallelizer` class keeps track of the task history. Developers can access the task history using the `get_task_history()` method. Example: ```python task_history = god_mode.get_task_history() @@ -136,7 +136,7 @@ for i, task in enumerate(task_history): ### Parallel Execution -`GodMode` employs multithreading to execute tasks concurrently. This parallel processing capability significantly improves the efficiency of handling multiple tasks simultaneously. +`ModelParallelizer` employs multithreading to execute tasks concurrently. This parallel processing capability significantly improves the efficiency of handling multiple tasks simultaneously. ### Response Visualization @@ -144,13 +144,13 @@ The structured tabular format used for presenting responses simplifies the compa ## 6. Examples -Let's explore additional usage examples to illustrate the versatility of `GodMode` in handling various NLP tasks. +Let's explore additional usage examples to illustrate the versatility of `ModelParallelizer` in handling various NLP tasks. ### Example 1: Sentiment Analysis ```python from swarms.models import OpenAIChat -from swarms.swarms import GodMode +from swarms.swarms import ModelParallelizer from swarms.workers.worker import Worker # Create an instance of an LLM for sentiment analysis @@ -184,9 +184,9 @@ worker3 = Worker( temperature=0.5, ) -# Register the worker agents with GodMode +# Register the worker agents with ModelParallelizer agents = [worker1, worker2, worker3] -god_mode = GodMode(agents) +god_mode = ModelParallelizer(agents) # Task for sentiment analysis task = "Please analyze the sentiment of the following sentence: 'This movie is amazing!'" @@ -200,16 +200,16 @@ god_mode.print_responses(task) ```python from swarms.models import OpenAIChat -from swarms.swarms import GodMode +from swarms.swarms import ModelParallelizer # Define LLMs for translation tasks translator1 = OpenAIChat(model_name="translator-en-fr", openai_api_key="api-key", temperature=0.7) translator2 = OpenAIChat(model_name="translator-en-es", openai_api_key="api-key", temperature=0.7) translator3 = OpenAIChat(model_name="translator-en-de", openai_api_key="api-key", temperature=0.7) -# Register translation agents with GodMode +# Register translation agents with ModelParallelizer translators = [translator1, translator2, translator3] -god_mode = GodMode(translators) +god_mode = ModelParallelizer(translators) # Task for translation task = "Translate the following English text to French: 'Hello, how are you?'" @@ -223,7 +223,7 @@ god_mode.print_responses(task) ```python from swarms.models import OpenAIChat -from swarms.swarms import GodMode +from swarms.swarms import ModelParallelizer # Define LLMs for summarization tasks @@ -231,9 +231,9 @@ summarizer1 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", t summarizer2 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", temperature=0.6) summarizer3 = OpenAIChat(model_name="summarizer-en", openai_api_key="api-key", temperature=0.6) -# Register summarization agents with GodMode +# Register summarization agents with ModelParallelizer summarizers = [summarizer1, summarizer2, summarizer3] -god_mode = GodMode(summarizers) +god_mode = ModelParallelizer(summarizers) # Task for summarization task = "Summarize the main points of the article titled 'Climate Change and Its Impact on the Environment.'" @@ -244,6 +244,6 @@ god_mode.print_responses(task) ## 7. Conclusion -In conclusion, the `GodMode` class is a powerful tool for managing and orchestrating multiple Language Model Models in natural language processing tasks. Its ability to distribute tasks, collect responses, and present them in a structured format makes it invaluable for streamlining NLP workflows. By following the provided documentation, users can harness the full potential of `GodMode` to enhance their natural language processing projects. +In conclusion, the `ModelParallelizer` class is a powerful tool for managing and orchestrating multiple Language Model Models in natural language processing tasks. Its ability to distribute tasks, collect responses, and present them in a structured format makes it invaluable for streamlining NLP workflows. By following the provided documentation, users can harness the full potential of `ModelParallelizer` to enhance their natural language processing projects. For further information on specific LLMs or advanced usage, refer to the documentation of the respective models and their APIs. Additionally, external resources on parallel execution and response visualization can provide deeper insights into these topics. \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 0bca64c6..00d2080a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,7 +63,7 @@ nav: - Overview: "swarms/index.md" - swarms.swarms: - AbstractSwarm: "swarms/swarms/abstractswarm.md" - - GodMode: "swarms/swarms/godmode.md" + - ModelParallelizer: "swarms/swarms/ModelParallelizer.md" - Groupchat: "swarms/swarms/groupchat.md" - swarms.workers: - Overview: "swarms/workers/index.md" diff --git a/playground/swarms/godmode.py b/playground/swarms/godmode.py index f1269d98..4d18ef56 100644 --- a/playground/swarms/godmode.py +++ b/playground/swarms/godmode.py @@ -1,4 +1,4 @@ -from swarms.swarms import GodMode +from swarms.swarms import ModelParallelizer from swarms.models import OpenAIChat api_key = "" @@ -8,7 +8,7 @@ llm = OpenAIChat(openai_api_key=api_key) llms = [llm, llm, llm] -god_mode = GodMode(llms) +god_mode = ModelParallelizer(llms) task = "Generate a 10,000 word blog on health and wellness." diff --git a/scripts/auto_tests_docs/auto_docs.py b/scripts/auto_tests_docs/auto_docs.py new file mode 100644 index 00000000..d6e1060a --- /dev/null +++ b/scripts/auto_tests_docs/auto_docs.py @@ -0,0 +1,101 @@ +###### VERISON2 +import inspect +import os +import threading +from zeta import OpenAIChat +from scripts.auto_tests_docs.docs import DOCUMENTATION_WRITER_SOP +from zeta.nn.modules._activations import ( + AccurateGELUActivation, + ClippedGELUActivation, + FastGELUActivation, + GELUActivation, + LaplaceActivation, + LinearActivation, + MishActivation, + NewGELUActivation, + PytorchGELUTanh, + QuickGELUActivation, + ReLUSquaredActivation, +) +from zeta.nn.modules.dense_connect import DenseBlock +from zeta.nn.modules.dual_path_block import DualPathBlock +from zeta.nn.modules.feedback_block import FeedbackBlock +from zeta.nn.modules.highway_layer import HighwayLayer +from zeta.nn.modules.multi_scale_block import MultiScaleBlock +from zeta.nn.modules.recursive_block import RecursiveBlock +from dotenv import load_dotenv + +load_dotenv() + +api_key = os.getenv("OPENAI_API_KEY") + +model = OpenAIChat( + model_name="gpt-4", + openai_api_key=api_key, + max_tokens=4000, +) + + +def process_documentation(cls): + """ + Process the documentation for a given class using OpenAI model and save it in a Markdown file. + """ + doc = inspect.getdoc(cls) + source = inspect.getsource(cls) + input_content = ( + f"Class Name: {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource" + f" Code:\n{source}" + ) + print(input_content) + + # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content) + processed_content = model(DOCUMENTATION_WRITER_SOP(input_content, "zeta")) + + doc_content = f"# {cls.__name__}\n\n{processed_content}\n" + + # Create the directory if it doesn't exist + dir_path = "docs/zeta/nn/modules" + os.makedirs(dir_path, exist_ok=True) + + # Write the processed documentation to a Markdown file + file_path = os.path.join(dir_path, f"{cls.__name__.lower()}.md") + with open(file_path, "w") as file: + file.write(doc_content) + + +def main(): + classes = [ + DenseBlock, + HighwayLayer, + MultiScaleBlock, + FeedbackBlock, + DualPathBlock, + RecursiveBlock, + PytorchGELUTanh, + NewGELUActivation, + GELUActivation, + FastGELUActivation, + QuickGELUActivation, + ClippedGELUActivation, + AccurateGELUActivation, + MishActivation, + LinearActivation, + LaplaceActivation, + ReLUSquaredActivation, + ] + + threads = [] + for cls in classes: + thread = threading.Thread(target=process_documentation, args=(cls,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + print("Documentation generated in 'docs/zeta/nn/modules' directory.") + + +if __name__ == "__main__": + main() diff --git a/scripts/auto_tests_docs/auto_tests.py b/scripts/auto_tests_docs/auto_tests.py new file mode 100644 index 00000000..70a3d750 --- /dev/null +++ b/scripts/auto_tests_docs/auto_tests.py @@ -0,0 +1,122 @@ +import inspect +import os +import re +import threading +from swarms import OpenAIChat +from scripts.auto_tests_docs.docs import TEST_WRITER_SOP_PROMPT +from zeta.nn.modules._activations import ( + AccurateGELUActivation, + ClippedGELUActivation, + FastGELUActivation, + GELUActivation, + LaplaceActivation, + LinearActivation, + MishActivation, + NewGELUActivation, + PytorchGELUTanh, + QuickGELUActivation, + ReLUSquaredActivation, +) +from zeta.nn.modules.dense_connect import DenseBlock +from zeta.nn.modules.dual_path_block import DualPathBlock +from zeta.nn.modules.feedback_block import FeedbackBlock +from zeta.nn.modules.highway_layer import HighwayLayer +from zeta.nn.modules.multi_scale_block import MultiScaleBlock +from zeta.nn.modules.recursive_block import RecursiveBlock +from dotenv import load_dotenv + +load_dotenv() + +api_key = os.getenv("OPENAI_API_KEY") + +model = OpenAIChat( + model_name="gpt-4", + openai_api_key=api_key, + max_tokens=4000, +) + + +def extract_code_from_markdown(markdown_content: str): + """ + Extracts code blocks from a Markdown string and returns them as a single string. + + Args: + - markdown_content (str): The Markdown content as a string. + + Returns: + - str: A single string containing all the code blocks separated by newlines. + """ + # Regular expression for fenced code blocks + pattern = r"```(?:\w+\n)?(.*?)```" + matches = re.findall(pattern, markdown_content, re.DOTALL) + + # Concatenate all code blocks separated by newlines + return "\n".join(code.strip() for code in matches) + + +def create_test(cls): + """ + Process the documentation for a given class using OpenAI model and save it in a Python file. + """ + doc = inspect.getdoc(cls) + source = inspect.getsource(cls) + input_content = ( + f"Class Name: {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource" + f" Code:\n{source}" + ) + print(input_content) + + # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content) + processed_content = model( + TEST_WRITER_SOP_PROMPT(input_content, "zeta", "zeta.nn") + ) + processed_content = extract_code_from_markdown(processed_content) + + doc_content = f"# {cls.__name__}\n\n{processed_content}\n" + + # Create the directory if it doesn't exist + dir_path = "tests/nn/modules" + os.makedirs(dir_path, exist_ok=True) + + # Write the processed documentation to a Python file + file_path = os.path.join(dir_path, f"{cls.__name__.lower()}.py") + with open(file_path, "w") as file: + file.write(doc_content) + + +def main(): + classes = [ + DenseBlock, + HighwayLayer, + MultiScaleBlock, + FeedbackBlock, + DualPathBlock, + RecursiveBlock, + PytorchGELUTanh, + NewGELUActivation, + GELUActivation, + FastGELUActivation, + QuickGELUActivation, + ClippedGELUActivation, + AccurateGELUActivation, + MishActivation, + LinearActivation, + LaplaceActivation, + ReLUSquaredActivation, + ] + + threads = [] + for cls in classes: + thread = threading.Thread(target=create_test, args=(cls,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + print("Tests generated in 'docs/zeta/nn/modules' directory.") + + +if __name__ == "__main__": + main() diff --git a/scripts/auto_tests_docs/docs.py b/scripts/auto_tests_docs/docs.py new file mode 100644 index 00000000..684bf6dd --- /dev/null +++ b/scripts/auto_tests_docs/docs.py @@ -0,0 +1,199 @@ +def DOCUMENTATION_WRITER_SOP( + task: str, + module: str, +): + documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the {module} code below follow the outline for the {module} library, + provide many examples and teach the user about the code, provide examples for every function, make the documentation 10,000 words, + provide many usage examples and note this is markdown docs, create the documentation for the code to document, + put the arguments and methods in a table in markdown to make it visually seamless + + Now make the professional documentation for this code, provide the architecture and how the class works and why it works that way, + it's purpose, provide args, their types, 3 ways of usage examples, in examples show all the code like imports main example etc + + BE VERY EXPLICIT AND THOROUGH, MAKE IT DEEP AND USEFUL + + ######## + Step 1: Understand the purpose and functionality of the module or framework + + Read and analyze the description provided in the documentation to understand the purpose and functionality of the module or framework. + Identify the key features, parameters, and operations performed by the module or framework. + Step 2: Provide an overview and introduction + + Start the documentation by providing a brief overview and introduction to the module or framework. + Explain the importance and relevance of the module or framework in the context of the problem it solves. + Highlight any key concepts or terminology that will be used throughout the documentation. + Step 3: Provide a class or function definition + + Provide the class or function definition for the module or framework. + Include the parameters that need to be passed to the class or function and provide a brief description of each parameter. + Specify the data types and default values for each parameter. + Step 4: Explain the functionality and usage + + Provide a detailed explanation of how the module or framework works and what it does. + Describe the steps involved in using the module or framework, including any specific requirements or considerations. + Provide code examples to demonstrate the usage of the module or framework. + Explain the expected inputs and outputs for each operation or function. + Step 5: Provide additional information and tips + + Provide any additional information or tips that may be useful for using the module or framework effectively. + Address any common issues or challenges that developers may encounter and provide recommendations or workarounds. + Step 6: Include references and resources + + Include references to any external resources or research papers that provide further information or background on the module or framework. + Provide links to relevant documentation or websites for further exploration. + Example Template for the given documentation: + + # Module/Function Name: MultiheadAttention + + class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None): + ``` + Creates a multi-head attention module for joint information representation from the different subspaces. + + Parameters: + - embed_dim (int): Total dimension of the model. + - num_heads (int): Number of parallel attention heads. The embed_dim will be split across num_heads. + - dropout (float): Dropout probability on attn_output_weights. Default: 0.0 (no dropout). + - bias (bool): If specified, adds bias to input/output projection layers. Default: True. + - add_bias_kv (bool): If specified, adds bias to the key and value sequences at dim=0. Default: False. + - add_zero_attn (bool): If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False. + - kdim (int): Total number of features for keys. Default: None (uses kdim=embed_dim). + - vdim (int): Total number of features for values. Default: None (uses vdim=embed_dim). + - batch_first (bool): If True, the input and output tensors are provided as (batch, seq, feature). Default: False. + - device (torch.device): If specified, the tensors will be moved to the specified device. + - dtype (torch.dtype): If specified, the tensors will have the specified dtype. + ``` + + def forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False): + ``` + Forward pass of the multi-head attention module. + + Parameters: + - query (Tensor): Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False, or (N, L, E_q) when batch_first=True. + - key (Tensor): Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False, or (N, S, E_k) when batch_first=True. + - value (Tensor): Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False, or (N, S, E_v) when batch_first=True. + - key_padding_mask (Optional[Tensor]): If specified, a mask indicating elements to be ignored in key for attention computation. + - need_weights (bool): If specified, returns attention weights in addition to attention outputs. Default: True. + - attn_mask (Optional[Tensor]): If specified, a mask preventing attention to certain positions. + - average_attn_weights (bool): If true, returns averaged attention weights per head. Otherwise, returns attention weights separately per head. Note that this flag only has an effect when need_weights=True. Default: True. + - is_causal (bool): If specified, applies a causal mask as the attention mask. Default: False. + + Returns: + Tuple[Tensor, Optional[Tensor]]: + - attn_output (Tensor): Attention outputs of shape (L, E) for unbatched input, (L, N, E) when batch_first=False, or (N, L, E) when batch_first=True. + - attn_output_weights (Optional[Tensor]): Attention weights of shape (L, S) when unbatched or (N, L, S) when batched. Optional, only returned when need_weights=True. + ``` + + # Implementation of the forward pass of the attention module goes here + + return attn_output, attn_output_weights + + ``` + # Usage example: + + multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + attn_output, attn_output_weights = multihead_attn(query, key, value) + Note: + + The above template includes the class or function definition, parameters, description, and usage example. + To replicate the documentation for any other module or framework, follow the same structure and provide the specific details for that module or framework. + + + ############# DOCUMENT THE FOLLOWING CODE ######## + {task} + """ + return documentation + + +def TEST_WRITER_SOP_PROMPT(task: str, module: str, path: str, *args, **kwargs): + TESTS_PROMPT = f""" + + Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any + just write the best tests possible, the module is {module}, the file path is {path} + + + ######### TESTING GUIDE ############# + + # **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`** + + 1. **Preparation**: + - Install pytest: `pip install pytest`. + - Structure your project so that tests are in a separate `tests/` directory. + - Name your test files with the prefix `test_` for pytest to recognize them. + + 2. **Writing Basic Tests**: + - Use clear function names prefixed with `test_` (e.g., `test_check_value()`). + - Use assert statements to validate results. + + 3. **Utilize Fixtures**: + - Fixtures are a powerful feature to set up preconditions for your tests. + - Use `@pytest.fixture` decorator to define a fixture. + - Pass fixture name as an argument to your test to use it. + + 4. **Parameterized Testing**: + - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs. + - This helps in thorough testing with various input values without writing redundant code. + + 5. **Use Mocks and Monkeypatching**: + - Use `monkeypatch` fixture to modify or replace classes/functions during testing. + - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code. + + 6. **Exception Testing**: + - Test for expected exceptions using `pytest.raises(ExceptionType)`. + + 7. **Test Coverage**: + - Install pytest-cov: `pip install pytest-cov`. + - Run tests with `pytest --cov=my_module` to get a coverage report. + + 8. **Environment Variables and Secret Handling**: + - Store secrets and configurations in environment variables. + - Use libraries like `python-decouple` or `python-dotenv` to load environment variables. + - For tests, mock or set environment variables temporarily within the test environment. + + 9. **Grouping and Marking Tests**: + - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`). + - This allows for selectively running certain groups of tests. + + 10. **Use Plugins**: + - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs. + + 11. **Continuous Integration (CI)**: + - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions. + - Ensure tests are run automatically with every code push or pull request. + + 12. **Logging and Reporting**: + - Use `pytest`'s inbuilt logging. + - Integrate with tools like `Allure` for more comprehensive reporting. + + 13. **Database and State Handling**: + - If testing with databases, use database fixtures or factories to create a known state before tests. + - Clean up and reset state post-tests to maintain consistency. + + 14. **Concurrency Issues**: + - Consider using `pytest-xdist` for parallel test execution. + - Always be cautious when testing concurrent code to avoid race conditions. + + 15. **Clean Code Practices**: + - Ensure tests are readable and maintainable. + - Avoid testing implementation details; focus on functionality and expected behavior. + + 16. **Regular Maintenance**: + - Periodically review and update tests. + - Ensure that tests stay relevant as your codebase grows and changes. + + 17. **Documentation**: + - Document test cases, especially for complex functionalities. + - Ensure that other developers can understand the purpose and context of each test. + + 18. **Feedback Loop**: + - Use test failures as feedback for development. + - Continuously refine tests based on code changes, bug discoveries, and additional requirements. + + By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project. + + + ######### CREATE TESTS FOR THIS CODE: ####### + {task} + + """ + + return TESTS_PROMPT diff --git a/scripts/auto_tests_docs/update_mkdocs.py b/scripts/auto_tests_docs/update_mkdocs.py new file mode 100644 index 00000000..4901059f --- /dev/null +++ b/scripts/auto_tests_docs/update_mkdocs.py @@ -0,0 +1,60 @@ +import yaml + + +def update_mkdocs( + class_names, base_path="docs/zeta/nn/modules", mkdocs_file="mkdocs.yml" +): + """ + Update the mkdocs.yml file with new documentation links. + + Args: + - class_names: A list of class names for which documentation is generated. + - base_path: The base path where documentation Markdown files are stored. + - mkdocs_file: The path to the mkdocs.yml file. + """ + with open(mkdocs_file, "r") as file: + mkdocs_config = yaml.safe_load(file) + + # Find or create the 'zeta.nn.modules' section in 'nav' + zeta_modules_section = None + for section in mkdocs_config.get("nav", []): + if "zeta.nn.modules" in section: + zeta_modules_section = section["zeta.nn.modules"] + break + + if zeta_modules_section is None: + zeta_modules_section = {} + mkdocs_config["nav"].append({"zeta.nn.modules": zeta_modules_section}) + + # Add the documentation paths to the 'zeta.nn.modules' section + for class_name in class_names: + doc_path = f"{base_path}/{class_name.lower()}.md" + zeta_modules_section[class_name] = doc_path + + # Write the updated content back to mkdocs.yml + with open(mkdocs_file, "w") as file: + yaml.safe_dump(mkdocs_config, file, sort_keys=False) + + +# Example usage +classes = [ + "DenseBlock", + "HighwayLayer", + "MultiScaleBlock", + "FeedbackBlock", + "DualPathBlock", + "RecursiveBlock", + "PytorchGELUTanh", + "NewGELUActivation", + "GELUActivation", + "FastGELUActivation", + "QuickGELUActivation", + "ClippedGELUActivation", + "AccurateGELUActivation", + "MishActivation", + "LinearActivation", + "LaplaceActivation", + "ReLUSquaredActivation", +] + +update_mkdocs(classes) diff --git a/scripts/code_quality.sh b/scripts/code_quality.sh index 90153258..e3afec13 100755 --- a/scripts/code_quality.sh +++ b/scripts/code_quality.sh @@ -1,19 +1,19 @@ #!/bin/bash -# Navigate to the directory containing the 'swarms' folder +# Navigate to the directory containing the 'tests' folder # cd /path/to/your/code/directory # Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i) -# on all Python files (*.py) under the 'swarms' directory. -autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes swarms/ +# on all Python files (*.py) under the 'tests' directory. +autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes zeta/ # Run black with default settings, since black does not have an aggressiveness level. -# Black will format all Python files it finds in the 'swarms' directory. -black --experimental-string-processing swarms/ +# Black will format all Python files it finds in the 'tests' directory. +black --experimental-string-processing zeta/ -# Run ruff on the 'swarms' directory. +# Run ruff on the 'tests' directory. # Add any additional flags if needed according to your version of ruff. -ruff --unsafe_fix +ruff zeta/ --fix # YAPF -yapf --recursive --in-place --verbose --style=google --parallel swarms +yapf --recursive --in-place --verbose --style=google --parallel tests diff --git a/scripts/delete_pycache.sh b/scripts/delete_pycache.sh deleted file mode 100644 index db11f239..00000000 --- a/scripts/delete_pycache.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -# Find all __pycache__ directories and delete them -find . -type d -name "__pycache__" -exec rm -rf {} + \ No newline at end of file diff --git a/scripts/get_package_requirements.py b/scripts/get_package_requirements.py index 9494409b..0d57c028 100644 --- a/scripts/get_package_requirements.py +++ b/scripts/get_package_requirements.py @@ -13,18 +13,13 @@ def get_package_versions(requirements_path, output_path): for requirement in requirements: # Skip empty lines and comments - if ( - requirement.strip() == "" - or requirement.strip().startswith("#") - ): + if requirement.strip() == "" or requirement.strip().startswith("#"): continue # Extract package name package_name = requirement.split("==")[0].strip() try: - version = pkg_resources.get_distribution( - package_name - ).version + version = pkg_resources.get_distribution(package_name).version package_versions.append(f"{package_name}=={version}") except pkg_resources.DistributionNotFound: package_versions.append(f"{package_name}: not installed") diff --git a/scripts/requirementstxt_to_pyproject.py b/scripts/requirementstxt_to_pyproject.py index 5710db61..59f6946f 100644 --- a/scripts/requirementstxt_to_pyproject.py +++ b/scripts/requirementstxt_to_pyproject.py @@ -10,10 +10,7 @@ def update_pyproject_versions(pyproject_path): print(f"Error: The file '{pyproject_path}' was not found.") return except toml.TomlDecodeError: - print( - f"Error: The file '{pyproject_path}' is not a valid TOML" - " file." - ) + print(f"Error: The file '{pyproject_path}' is not a valid TOML file.") return dependencies = ( diff --git a/scripts/test_name.sh b/scripts/test_name.sh index cdc6a013..4123f870 100755 --- a/scripts/test_name.sh +++ b/scripts/test_name.sh @@ -4,5 +4,6 @@ do dir=$(dirname "$file") if [[ $filename != test_* ]]; then mv "$file" "$dir/test_$filename" + printf "\e[1;34mRenamed: \e[0m$file \e[1;32mto\e[0m $dir/test_$filename\n" fi done \ No newline at end of file diff --git a/swarms/agents/simple_agent.py b/swarms/agents/simple_agent.py index 3e4a65ae..1c6d3126 100644 --- a/swarms/agents/simple_agent.py +++ b/swarms/agents/simple_agent.py @@ -1,4 +1,5 @@ -from swarms import Conversation, AbstractLLM +from swarms.structs.conversation import Conversation +from swarms.models.base_llm import AbstractLLM # Run the language model in a loop for n iterations diff --git a/swarms/models/open_dalle.py b/swarms/models/open_dalle.py new file mode 100644 index 00000000..b43d6c2e --- /dev/null +++ b/swarms/models/open_dalle.py @@ -0,0 +1,66 @@ +from typing import Optional, Any + +import torch +from diffusers import AutoPipelineForText2Image +from swarms.models.base_multimodal_model import BaseMultiModalModel + + +class OpenDalle(BaseMultiModalModel): + """OpenDalle model class + + Attributes: + model_name (str): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1". + torch_dtype (torch.dtype): The torch data type to be used. Defaults to torch.float16. + device (str): The device to be used for computation. Defaults to "cuda". + + Examples: + >>> from swarms.models.open_dalle import OpenDalle + >>> od = OpenDalle() + >>> od.run("A picture of a cat") + + """ + + def __init__( + self, + model_name: str = "dataautogpt3/OpenDalleV1.1", + torch_dtype: Any = torch.float16, + device: str = "cuda", + *args, + **kwargs, + ): + """ + Initializes the OpenDalle model. + + Args: + model_name (str, optional): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1". + torch_dtype (torch.dtype, optional): The torch data type to be used. Defaults to torch.float16. + device (str, optional): The device to be used for computation. Defaults to "cuda". + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + self.pipeline = AutoPipelineForText2Image.from_pretrained( + model_name, torch_dtype=torch_dtype, *args, **kwargs + ).to(device) + + def run(self, task: Optional[str] = None, *args, **kwargs): + """Run the OpenDalle model + + Args: + task (str, optional): The task to be performed. Defaults to None. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + [type]: [description] + """ + try: + if task is None: + raise ValueError("Task cannot be None") + if not isinstance(task, str): + raise TypeError("Task must be a string") + if len(task) < 1: + raise ValueError("Task cannot be empty") + return self.pipeline(task, *args, **kwargs).images[0] + except Exception as error: + print(f"[ERROR][OpenDalle] {error}") + raise error diff --git a/swarms/prompts/__init__.py b/swarms/prompts/__init__.py index 6417dc85..2961c10d 100644 --- a/swarms/prompts/__init__.py +++ b/swarms/prompts/__init__.py @@ -6,7 +6,7 @@ from swarms.prompts.operations_agent_prompt import ( OPERATIONS_AGENT_PROMPT, ) from swarms.prompts.product_agent_prompt import PRODUCT_AGENT_PROMPT - +from swarms.prompts.documentation import DOCUMENTATION_WRITER_SOP __all__ = [ "CODE_INTERPRETER", @@ -15,4 +15,5 @@ __all__ = [ "LEGAL_AGENT_PROMPT", "OPERATIONS_AGENT_PROMPT", "PRODUCT_AGENT_PROMPT", + "DOCUMENTATION_WRITER_SOP" ] diff --git a/swarms/prompts/documentation.py b/swarms/prompts/documentation.py index 3ed2eb8c..d3601014 100644 --- a/swarms/prompts/documentation.py +++ b/swarms/prompts/documentation.py @@ -1,5 +1,5 @@ -def documentation(task: str): - documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the code below follow the outline for the library, +def DOCUMENTATION_WRITER_SOP(task: str, module: str, ): + documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the {module} code below follow the outline for the {module} library, provide many examples and teach the user about the code, provide examples for every function, make the documentation 10,000 words, provide many usage examples and note this is markdown docs, create the documentation for the code to document, put the arguments and methods in a table in markdown to make it visually seamless diff --git a/swarms/prompts/tests.py b/swarms/prompts/tests.py index df1ee92d..c7e45c1f 100644 --- a/swarms/prompts/tests.py +++ b/swarms/prompts/tests.py @@ -1,89 +1,94 @@ -TESTS_PROMPT = """ +def TEST_WRITER_SOP_PROMPT(task: str, module: str, path: str, *args, **kwargs): -Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any -just write the best tests possible: + TESTS_PROMPT = f""" + Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any + just write the best tests possible, the module is {module}, the file path is {path} -######### TESTING GUIDE ############# -# **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`** + ######### TESTING GUIDE ############# -1. **Preparation**: - - Install pytest: `pip install pytest`. - - Structure your project so that tests are in a separate `tests/` directory. - - Name your test files with the prefix `test_` for pytest to recognize them. + # **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`** -2. **Writing Basic Tests**: - - Use clear function names prefixed with `test_` (e.g., `test_check_value()`). - - Use assert statements to validate results. + 1. **Preparation**: + - Install pytest: `pip install pytest`. + - Structure your project so that tests are in a separate `tests/` directory. + - Name your test files with the prefix `test_` for pytest to recognize them. -3. **Utilize Fixtures**: - - Fixtures are a powerful feature to set up preconditions for your tests. - - Use `@pytest.fixture` decorator to define a fixture. - - Pass fixture name as an argument to your test to use it. + 2. **Writing Basic Tests**: + - Use clear function names prefixed with `test_` (e.g., `test_check_value()`). + - Use assert statements to validate results. -4. **Parameterized Testing**: - - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs. - - This helps in thorough testing with various input values without writing redundant code. + 3. **Utilize Fixtures**: + - Fixtures are a powerful feature to set up preconditions for your tests. + - Use `@pytest.fixture` decorator to define a fixture. + - Pass fixture name as an argument to your test to use it. -5. **Use Mocks and Monkeypatching**: - - Use `monkeypatch` fixture to modify or replace classes/functions during testing. - - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code. + 4. **Parameterized Testing**: + - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs. + - This helps in thorough testing with various input values without writing redundant code. -6. **Exception Testing**: - - Test for expected exceptions using `pytest.raises(ExceptionType)`. + 5. **Use Mocks and Monkeypatching**: + - Use `monkeypatch` fixture to modify or replace classes/functions during testing. + - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code. -7. **Test Coverage**: - - Install pytest-cov: `pip install pytest-cov`. - - Run tests with `pytest --cov=my_module` to get a coverage report. + 6. **Exception Testing**: + - Test for expected exceptions using `pytest.raises(ExceptionType)`. -8. **Environment Variables and Secret Handling**: - - Store secrets and configurations in environment variables. - - Use libraries like `python-decouple` or `python-dotenv` to load environment variables. - - For tests, mock or set environment variables temporarily within the test environment. + 7. **Test Coverage**: + - Install pytest-cov: `pip install pytest-cov`. + - Run tests with `pytest --cov=my_module` to get a coverage report. -9. **Grouping and Marking Tests**: - - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`). - - This allows for selectively running certain groups of tests. + 8. **Environment Variables and Secret Handling**: + - Store secrets and configurations in environment variables. + - Use libraries like `python-decouple` or `python-dotenv` to load environment variables. + - For tests, mock or set environment variables temporarily within the test environment. -10. **Use Plugins**: - - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs. + 9. **Grouping and Marking Tests**: + - Use `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`). + - This allows for selectively running certain groups of tests. -11. **Continuous Integration (CI)**: - - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions. - - Ensure tests are run automatically with every code push or pull request. + 10. **Use Plugins**: + - Utilize the rich ecosystem of pytest plugins (e.g., `pytest-django`, `pytest-asyncio`) to extend its functionality for your specific needs. -12. **Logging and Reporting**: - - Use `pytest`'s inbuilt logging. - - Integrate with tools like `Allure` for more comprehensive reporting. + 11. **Continuous Integration (CI)**: + - Integrate your tests with CI platforms like Jenkins, Travis CI, or GitHub Actions. + - Ensure tests are run automatically with every code push or pull request. -13. **Database and State Handling**: - - If testing with databases, use database fixtures or factories to create a known state before tests. - - Clean up and reset state post-tests to maintain consistency. + 12. **Logging and Reporting**: + - Use `pytest`'s inbuilt logging. + - Integrate with tools like `Allure` for more comprehensive reporting. -14. **Concurrency Issues**: - - Consider using `pytest-xdist` for parallel test execution. - - Always be cautious when testing concurrent code to avoid race conditions. + 13. **Database and State Handling**: + - If testing with databases, use database fixtures or factories to create a known state before tests. + - Clean up and reset state post-tests to maintain consistency. -15. **Clean Code Practices**: - - Ensure tests are readable and maintainable. - - Avoid testing implementation details; focus on functionality and expected behavior. + 14. **Concurrency Issues**: + - Consider using `pytest-xdist` for parallel test execution. + - Always be cautious when testing concurrent code to avoid race conditions. -16. **Regular Maintenance**: - - Periodically review and update tests. - - Ensure that tests stay relevant as your codebase grows and changes. + 15. **Clean Code Practices**: + - Ensure tests are readable and maintainable. + - Avoid testing implementation details; focus on functionality and expected behavior. -17. **Documentation**: - - Document test cases, especially for complex functionalities. - - Ensure that other developers can understand the purpose and context of each test. + 16. **Regular Maintenance**: + - Periodically review and update tests. + - Ensure that tests stay relevant as your codebase grows and changes. -18. **Feedback Loop**: - - Use test failures as feedback for development. - - Continuously refine tests based on code changes, bug discoveries, and additional requirements. + 17. **Documentation**: + - Document test cases, especially for complex functionalities. + - Ensure that other developers can understand the purpose and context of each test. -By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project. + 18. **Feedback Loop**: + - Use test failures as feedback for development. + - Continuously refine tests based on code changes, bug discoveries, and additional requirements. + By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project. -######### CREATE TESTS FOR THIS CODE: ####### -""" + ######### CREATE TESTS FOR THIS CODE: ####### + {task} + + """ + + return TESTS_PROMPT \ No newline at end of file diff --git a/swarms/swarms/__init__.py b/swarms/swarms/__init__.py index 38ced622..ac45a48e 100644 --- a/swarms/swarms/__init__.py +++ b/swarms/swarms/__init__.py @@ -1,11 +1,11 @@ from swarms.structs.autoscaler import AutoScaler -from swarms.swarms.god_mode import GodMode +from swarms.swarms.god_mode import ModelParallelizer from swarms.swarms.multi_agent_collab import MultiAgentCollaboration from swarms.swarms.base import AbstractSwarm __all__ = [ "AutoScaler", - "GodMode", + "ModelParallelizer", "MultiAgentCollaboration", "AbstractSwarm", ] diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index 29178b2c..dfe4e232 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -11,17 +11,17 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class GodMode: +class ModelParallelizer: """ - GodMode + ModelParallelizer ----- Architecture: How it works: - 1. GodMode receives a task from the user. - 2. GodMode distributes the task to all LLMs. - 3. GodMode collects the responses from all LLMs. - 4. GodMode prints the responses from all LLMs. + 1. ModelParallelizer receives a task from the user. + 2. ModelParallelizer distributes the task to all LLMs. + 3. ModelParallelizer collects the responses from all LLMs. + 4. ModelParallelizer prints the responses from all LLMs. Parameters: llms: list of LLMs @@ -31,7 +31,7 @@ class GodMode: print_responses(task): print responses from all LLMs Usage: - god_mode = GodMode(llms) + god_mode = ModelParallelizer(llms) god_mode.run(task) god_mode.print_responses(task) diff --git a/tests/models/test_open_dalle.py b/tests/models/test_open_dalle.py new file mode 100644 index 00000000..2483d705 --- /dev/null +++ b/tests/models/test_open_dalle.py @@ -0,0 +1,59 @@ +import pytest +import torch +from swarms.models.open_dalle import OpenDalle + + +def test_init(): + od = OpenDalle() + assert isinstance(od, OpenDalle) + + +def test_init_custom_model(): + od = OpenDalle(model_name="custom_model") + assert od.pipeline.model_name == "custom_model" + + +def test_init_custom_dtype(): + od = OpenDalle(torch_dtype=torch.float32) + assert od.pipeline.torch_dtype == torch.float32 + + +def test_init_custom_device(): + od = OpenDalle(device="cpu") + assert od.pipeline.device == "cpu" + + +def test_run(): + od = OpenDalle() + result = od.run("A picture of a cat") + assert isinstance(result, torch.Tensor) + + +def test_run_no_task(): + od = OpenDalle() + with pytest.raises(ValueError, match="Task cannot be None"): + od.run(None) + + +def test_run_non_string_task(): + od = OpenDalle() + with pytest.raises(TypeError, match="Task must be a string"): + od.run(123) + + +def test_run_empty_task(): + od = OpenDalle() + with pytest.raises(ValueError, match="Task cannot be empty"): + od.run("") + + +def test_run_custom_args(): + od = OpenDalle() + result = od.run("A picture of a cat", custom_arg="custom_value") + assert isinstance(result, torch.Tensor) + + +def test_run_error(): + od = OpenDalle() + with pytest.raises(Exception): + od.run("A picture of a cat", raise_error=True)