From 51a271172f71f4be86f86748f93a03d16f626eae Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 19 Dec 2023 15:02:22 -0500 Subject: [PATCH] [CODE QUALITY] [MEMORY][New DB][SQLite] --- playground/demos/assembly/assembly.py | 3 - playground/demos/autotemp/autotemp_example.py | 1 - .../demos/gemini_benchmarking/gemini_chat.py | 7 +- .../demos/gemini_benchmarking/gemini_react.py | 4 +- playground/demos/nutrition/nutrition.py | 2 +- playground/demos/optimize_llm_stack/vortex.py | 1 - .../demos/swarm_of_mma_manufacturing/main.py | 1 - playground/models/kosmos2.py | 2 +- playground/tools/agent_with_tools.py | 1 - swarms/memory/__init__.py | 3 +- swarms/memory/qdrant.py | 2 +- swarms/memory/short_term_memory.py | 11 +- swarms/memory/sqlite.py | 120 ++++++++++++++++++ swarms/models/base_tts.py | 2 +- swarms/models/gemini.py | 5 +- swarms/models/simple_ada.py | 1 - swarms/structs/agent.py | 1 - swarms/structs/autoscaler.py | 2 - swarms/structs/base.py | 2 +- swarms/utils/apa.py | 2 - swarms/utils/loggers.py | 16 +-- swarms/utils/main.py | 1 - swarms/utils/pdf_to_text.py | 1 - tests/memory/test_short_term_memory.py | 1 - tests/memory/test_sqlite.py | 104 +++++++++++++++ tests/models/test_bioclip.py | 2 - tests/models/test_gemini.py | 4 +- tests/models/test_gpt4_vision_api.py | 16 +-- tests/models/test_multion.py | 2 +- tests/models/test_togther.py | 1 - tests/models/test_vllm.py | 2 +- tests/structs/test_agent.py | 4 +- tests/structs/test_autoscaler.py | 10 +- tests/structs/test_task.py | 3 +- tests/swarms/test_multi_agent_collab.py | 2 +- tests/upload_tests_to_issues.py | 2 - tests/utils/test_phoenix_handler.py | 2 - 37 files changed, 278 insertions(+), 68 deletions(-) create mode 100644 swarms/memory/sqlite.py create mode 100644 tests/memory/test_sqlite.py diff --git a/playground/demos/assembly/assembly.py b/playground/demos/assembly/assembly.py index b82e075c..704c80d4 100644 --- a/playground/demos/assembly/assembly.py +++ b/playground/demos/assembly/assembly.py @@ -1,8 +1,5 @@ from swarms.structs import Agent from swarms.models.gpt4_vision_api import GPT4VisionAPI -from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( - MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, -) llm = GPT4VisionAPI() diff --git a/playground/demos/autotemp/autotemp_example.py b/playground/demos/autotemp/autotemp_example.py index c5f86416..ccbd54c3 100644 --- a/playground/demos/autotemp/autotemp_example.py +++ b/playground/demos/autotemp/autotemp_example.py @@ -1,4 +1,3 @@ -from swarms.models import OpenAIChat from autotemp import AutoTemp # Your OpenAI API key diff --git a/playground/demos/gemini_benchmarking/gemini_chat.py b/playground/demos/gemini_benchmarking/gemini_chat.py index b1f12ee7..6d9dc7ae 100644 --- a/playground/demos/gemini_benchmarking/gemini_chat.py +++ b/playground/demos/gemini_benchmarking/gemini_chat.py @@ -21,5 +21,8 @@ model = Gemini( ) -out = model.chat("Create the code for a react component that displays a name", img=img) -print(out) \ No newline at end of file +out = model.chat( + "Create the code for a react component that displays a name", + img=img, +) +print(out) diff --git a/playground/demos/gemini_benchmarking/gemini_react.py b/playground/demos/gemini_benchmarking/gemini_react.py index 76caf974..022405e9 100644 --- a/playground/demos/gemini_benchmarking/gemini_react.py +++ b/playground/demos/gemini_benchmarking/gemini_react.py @@ -22,5 +22,7 @@ model = Gemini( # Run the model -out = model.run("Create the code for a react component that displays a name") +out = model.run( + "Create the code for a react component that displays a name" +) print(out) diff --git a/playground/demos/nutrition/nutrition.py b/playground/demos/nutrition/nutrition.py index aca079ba..428560e3 100644 --- a/playground/demos/nutrition/nutrition.py +++ b/playground/demos/nutrition/nutrition.py @@ -2,7 +2,7 @@ import os import base64 import requests from dotenv import load_dotenv -from swarms.models import Anthropic, OpenAIChat +from swarms.models import OpenAIChat from swarms.structs import Agent # Load environment variables diff --git a/playground/demos/optimize_llm_stack/vortex.py b/playground/demos/optimize_llm_stack/vortex.py index 438c1451..a40c29b9 100644 --- a/playground/demos/optimize_llm_stack/vortex.py +++ b/playground/demos/optimize_llm_stack/vortex.py @@ -1,5 +1,4 @@ import os -import subprocess from dotenv import load_dotenv diff --git a/playground/demos/swarm_of_mma_manufacturing/main.py b/playground/demos/swarm_of_mma_manufacturing/main.py index 37938608..05b0e8e5 100644 --- a/playground/demos/swarm_of_mma_manufacturing/main.py +++ b/playground/demos/swarm_of_mma_manufacturing/main.py @@ -20,7 +20,6 @@ from termcolor import colored from swarms.models import GPT4VisionAPI from swarms.structs import Agent -from swarms.utils.phoenix_handler import phoenix_trace_decorator load_dotenv() api_key = os.getenv("OPENAI_API_KEY") diff --git a/playground/models/kosmos2.py b/playground/models/kosmos2.py index ce39a710..6fc4df02 100644 --- a/playground/models/kosmos2.py +++ b/playground/models/kosmos2.py @@ -1,4 +1,4 @@ -from swarms.models.kosmos2 import Kosmos2, Detections +from swarms.models.kosmos2 import Kosmos2 from PIL import Image diff --git a/playground/tools/agent_with_tools.py b/playground/tools/agent_with_tools.py index ee4a8ef7..3bad0b1d 100644 --- a/playground/tools/agent_with_tools.py +++ b/playground/tools/agent_with_tools.py @@ -1,7 +1,6 @@ import os from swarms.models import OpenAIChat from swarms.structs import Agent -from swarms.tools.tool import tool from dotenv import load_dotenv load_dotenv() diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index 4f92880a..71a7871d 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -1,4 +1,5 @@ from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.short_term_memory import ShortTermMemory +from swarms.memory.sqlite import SQLiteDB -__all__ = ["VectorDatabase", "ShortTermMemory"] +__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"] diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 83ff5593..40f9979c 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -82,7 +82,7 @@ class Qdrant: f"Collection '{self.collection_name}' already" " exists." ) - except Exception as e: + except Exception: self.client.create_collection( collection_name=self.collection_name, vectors_config=VectorParams( diff --git a/swarms/memory/short_term_memory.py b/swarms/memory/short_term_memory.py index 53daf332..d380fba5 100644 --- a/swarms/memory/short_term_memory.py +++ b/swarms/memory/short_term_memory.py @@ -12,8 +12,8 @@ class ShortTermMemory(BaseStructure): autosave (bool, optional): _description_. Defaults to True. *args: _description_ **kwargs: _description_ - - + + Example: >>> from swarms.memory.short_term_memory import ShortTermMemory >>> stm = ShortTermMemory() @@ -22,9 +22,10 @@ class ShortTermMemory(BaseStructure): >>> stm.add(role="agent", message="I am fine.") >>> stm.add(role="agent", message="How are you?") >>> stm.add(role="agent", message="I am fine.") - - + + """ + def __init__( self, return_str: bool = True, @@ -93,7 +94,7 @@ class ShortTermMemory(BaseStructure): index (_type_): _description_ role (str): _description_ message (str): _description_ - + """ self.short_term_memory[index] = { "role": role, diff --git a/swarms/memory/sqlite.py b/swarms/memory/sqlite.py new file mode 100644 index 00000000..eed4ee2c --- /dev/null +++ b/swarms/memory/sqlite.py @@ -0,0 +1,120 @@ +from typing import List, Tuple, Any, Optional +from swarms.memory.base_vectordb import VectorDatabase + +try: + import sqlite3 +except ImportError: + raise ImportError( + "Please install sqlite3 to use the SQLiteDB class." + ) + + +class SQLiteDB(VectorDatabase): + """ + A reusable class for SQLite database operations with methods for adding, + deleting, updating, and querying data. + + Attributes: + db_path (str): The file path to the SQLite database. + """ + + def __init__(self, db_path: str): + """ + Initializes the SQLiteDB class with the given database path. + + Args: + db_path (str): The file path to the SQLite database. + """ + self.db_path = db_path + + def execute_query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Executes a SQL query and returns fetched results. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. + + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params or ()) + return cursor.fetchall() + except Exception as error: + print(f"Error executing query: {error}") + raise error + + def add(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Adds a new entry to the database. + + Args: + query (str): The SQL query for insertion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error adding new entry: {error}") + raise error + + def delete(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Deletes an entry from the database. + + Args: + query (str): The SQL query for deletion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error deleting entry: {error}") + raise error + + def update(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Updates an entry in the database. + + Args: + query (str): The SQL query for updating. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error updating entry: {error}") + raise error + + def query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Fetches data from the database based on a query. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. + + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + return self.execute_query(query, params) + except Exception as error: + print(f"Error querying database: {error}") + raise error diff --git a/swarms/models/base_tts.py b/swarms/models/base_tts.py index 0faaf6ff..60896856 100644 --- a/swarms/models/base_tts.py +++ b/swarms/models/base_tts.py @@ -1,7 +1,7 @@ import wave from typing import Optional from swarms.models.base_llm import AbstractLLM -from abc import ABC, abstractmethod +from abc import abstractmethod class BaseTTSModel(AbstractLLM): diff --git a/swarms/models/gemini.py b/swarms/models/gemini.py index 8cb09ca5..d12ea7d9 100644 --- a/swarms/models/gemini.py +++ b/swarms/models/gemini.py @@ -174,7 +174,10 @@ class Gemini(BaseMultiModalModel): return response.text else: response = self.model.generate_content( - prepare_prompt, stream=self.stream, *args, **kwargs + prepare_prompt, + stream=self.stream, + *args, + **kwargs, ) return response.text except Exception as error: diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index e9a599d0..e96995c4 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -1,4 +1,3 @@ -import os from openai import OpenAI client = OpenAI() diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 9d48791e..be5c7121 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -12,7 +12,6 @@ from termcolor import colored from swarms.memory.base_vectordb import VectorDatabase from swarms.prompts.agent_system_prompts import ( - FLOW_SYSTEM_PROMPT, AGENT_SYSTEM_PROMPT_3, agent_system_prompt_2, ) diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index 1cb31333..6f07d0d3 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -3,8 +3,6 @@ import queue import threading from time import sleep from typing import Callable, Dict, List, Optional -import asyncio -import concurrent.futures from termcolor import colored diff --git a/swarms/structs/base.py b/swarms/structs/base.py index 7d365b23..adfa974d 100644 --- a/swarms/structs/base.py +++ b/swarms/structs/base.py @@ -1,6 +1,6 @@ import json import os -from abc import ABC, abstractmethod +from abc import ABC from typing import Optional, Any, Dict, List from datetime import datetime import asyncio diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index f2e1bb38..fa73b7b4 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -1,7 +1,5 @@ from enum import Enum, unique, auto import abc -import hashlib -import re from typing import List, Optional import json from dataclasses import dataclass, field diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index a0dec94d..68477132 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -487,21 +487,21 @@ def print_action_base(action: Action): """ if action.content != "": logger.typewriter_log( - f"content:", Fore.YELLOW, f"{action.content}" + "content:", Fore.YELLOW, f"{action.content}" ) logger.typewriter_log( - f"Thought:", Fore.YELLOW, f"{action.thought}" + "Thought:", Fore.YELLOW, f"{action.thought}" ) if len(action.plan) > 0: logger.typewriter_log( - f"Plan:", + "Plan:", Fore.YELLOW, ) for line in action.plan: line = line.lstrip("- ") logger.typewriter_log("- ", Fore.GREEN, line.strip()) logger.typewriter_log( - f"Criticism:", Fore.YELLOW, f"{action.criticism}" + "Criticism:", Fore.YELLOW, f"{action.criticism}" ) @@ -515,15 +515,15 @@ def print_action_tool(action: Action): Returns: None """ - logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}") + logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}") logger.typewriter_log( - f"Tool Input:", Fore.BLUE, f"{action.tool_input}" + "Tool Input:", Fore.BLUE, f"{action.tool_input}" ) output = ( action.tool_output if action.tool_output != "" else "None" ) - logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}") + logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}") color = Fore.RED if action.tool_output_status == ToolCallStatus.ToolCallSuccess: @@ -534,7 +534,7 @@ def print_action_tool(action: Action): color = Fore.YELLOW logger.typewriter_log( - f"Tool Call Status:", + "Tool Call Status:", Fore.BLUE, f"{color}{action.tool_output_status.name}{Style.RESET_ALL}", ) diff --git a/swarms/utils/main.py b/swarms/utils/main.py index c9c0f380..b94fae11 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -9,7 +9,6 @@ from typing import Dict import boto3 import numpy as np -import pandas as pd import requests diff --git a/swarms/utils/pdf_to_text.py b/swarms/utils/pdf_to_text.py index 35309dd3..6d589ad5 100644 --- a/swarms/utils/pdf_to_text.py +++ b/swarms/utils/pdf_to_text.py @@ -1,5 +1,4 @@ import sys -import os try: import PyPDF2 diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 903c3a0e..0b66b749 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -1,4 +1,3 @@ -import pytest from swarms.memory.short_term_memory import ShortTermMemory import threading diff --git a/tests/memory/test_sqlite.py b/tests/memory/test_sqlite.py new file mode 100644 index 00000000..6b4213b0 --- /dev/null +++ b/tests/memory/test_sqlite.py @@ -0,0 +1,104 @@ +import pytest +import sqlite3 +from swarms.memory.sqlite import SQLiteDB + + +@pytest.fixture +def db(): + conn = sqlite3.connect(":memory:") + conn.execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" + ) + conn.commit() + return SQLiteDB(":memory:") + + +def test_add(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [(1, "test")] + + +def test_delete(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.delete("DELETE FROM test WHERE name = ?", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [] + + +def test_update(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.update( + "UPDATE test SET name = ? WHERE name = ?", ("new", "test") + ) + result = db.query("SELECT * FROM test") + assert result == [(1, "new")] + + +def test_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test WHERE name = ?", ("test",)) + assert result == [(1, "test")] + + +def test_execute_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.execute_query( + "SELECT * FROM test WHERE name = ?", ("test",) + ) + assert result == [(1, "test")] + + +def test_add_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.add("INSERT INTO test (name) VALUES (?)") + + +def test_delete_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.delete("DELETE FROM test WHERE name = ?") + + +def test_update_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.update("UPDATE test SET name = ? WHERE name = ?") + + +def test_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.query("SELECT * FROM test WHERE name = ?") + + +def test_execute_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.execute_query("SELECT * FROM test WHERE name = ?") + + +def test_add_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.add("INSERT INTO wrong (name) VALUES (?)", ("test",)) + + +def test_delete_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.delete("DELETE FROM wrong WHERE name = ?", ("test",)) + + +def test_update_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.update( + "UPDATE wrong SET name = ? WHERE name = ?", + ("new", "test"), + ) + + +def test_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.query("SELECT * FROM wrong WHERE name = ?", ("test",)) + + +def test_execute_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.execute_query( + "SELECT * FROM wrong WHERE name = ?", ("test",) + ) diff --git a/tests/models/test_bioclip.py b/tests/models/test_bioclip.py index 99e1e343..1e07df6d 100644 --- a/tests/models/test_bioclip.py +++ b/tests/models/test_bioclip.py @@ -1,5 +1,4 @@ # Import necessary modules and define fixtures if needed -import os import pytest import torch from PIL import Image @@ -156,7 +155,6 @@ def test_clip_inference_performance( def test_clip_preprocessing_pipelines( clip_instance, sample_image_path ): - labels = ["label1", "label2"] image = Image.open(sample_image_path) # Test preprocessing for training diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index c6f3e023..2a1d4ad4 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -110,7 +110,7 @@ def test_gemini_init_missing_api_key(): with pytest.raises( ValueError, match="Please provide a Gemini API key" ): - model = Gemini(gemini_api_key=None) + Gemini(gemini_api_key=None) # Test Gemini initialization with missing model name @@ -118,7 +118,7 @@ def test_gemini_init_missing_model_name(): with pytest.raises( ValueError, match="Please provide a model name" ): - model = Gemini(model_name=None) + Gemini(model_name=None) # Test Gemini run method with empty task diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index dfd03e27..c7758a36 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -48,7 +48,7 @@ def test_run_success(vision_api): def test_run_request_error(vision_api): with patch( "requests.post", side_effect=RequestException("Request Error") - ) as mock_post: + ): with pytest.raises(RequestException): vision_api.run("What is this?", img) @@ -58,7 +58,7 @@ def test_run_response_error(vision_api): with patch( "requests.post", return_value=Mock(json=lambda: expected_response), - ) as mock_post: + ): with pytest.raises(RuntimeError): vision_api.run("What is this?", img) @@ -153,7 +153,7 @@ async def test_arun_request_error(vision_api): "aiohttp.ClientSession.post", new_callable=AsyncMock, side_effect=Exception("Request Error"), - ) as mock_post: + ): with pytest.raises(Exception): await vision_api.arun("What is this?", img) @@ -181,7 +181,7 @@ def test_run_many_success(vision_api): def test_run_many_request_error(vision_api): with patch( "requests.post", side_effect=RequestException("Request Error") - ) as mock_post: + ): tasks = ["What is this?", "What is that?"] imgs = [img, img] with pytest.raises(RequestException): @@ -196,7 +196,7 @@ async def test_arun_json_decode_error(vision_api): return_value=AsyncMock( json=AsyncMock(side_effect=ValueError) ), - ) as mock_post: + ): with pytest.raises(ValueError): await vision_api.arun("What is this?", img) @@ -210,7 +210,7 @@ async def test_arun_api_error(vision_api): return_value=AsyncMock( json=AsyncMock(return_value=error_response) ), - ) as mock_post: + ): with pytest.raises(Exception, match="API Error"): await vision_api.arun("What is this?", img) @@ -224,7 +224,7 @@ async def test_arun_unexpected_response(vision_api): return_value=AsyncMock( json=AsyncMock(return_value=unexpected_response) ), - ) as mock_post: + ): with pytest.raises(Exception, match="Unexpected response"): await vision_api.arun("What is this?", img) @@ -247,6 +247,6 @@ async def test_arun_timeout(vision_api): "aiohttp.ClientSession.post", new_callable=AsyncMock, side_effect=asyncio.TimeoutError, - ) as mock_post: + ): with pytest.raises(asyncio.TimeoutError): await vision_api.arun("What is this?", img) diff --git a/tests/models/test_multion.py b/tests/models/test_multion.py index 416e6dc3..cc91b421 100644 --- a/tests/models/test_multion.py +++ b/tests/models/test_multion.py @@ -15,7 +15,7 @@ def mock_multion(): def test_multion_import(): with pytest.raises(ImportError): - import multion + pass def test_multion_init(): diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py index 75313a45..c28e69ae 100644 --- a/tests/models/test_togther.py +++ b/tests/models/test_togther.py @@ -1,4 +1,3 @@ -import os import requests import pytest from unittest.mock import patch, Mock diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py index d15a13b9..6eec8f27 100644 --- a/tests/models/test_vllm.py +++ b/tests/models/test_vllm.py @@ -113,7 +113,7 @@ def test_vllm_run_empty_task(vllm_instance): # Test initialization with invalid parameters def test_vllm_invalid_init(): with pytest.raises(ValueError): - vllm_instance = vLLM( + vLLM( model_name=None, tensor_parallel_size=-1, trust_remote_code="invalid", diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py index a8e1cf92..8e5b11be 100644 --- a/tests/structs/test_agent.py +++ b/tests/structs/test_agent.py @@ -347,7 +347,7 @@ def test_flow_response_filtering(flow_instance): def test_flow_undo_last(flow_instance): # Test the undo functionality response1 = flow_instance.run("Task 1") - response2 = flow_instance.run("Task 2") + flow_instance.run("Task 2") previous_state, message = flow_instance.undo_last() assert response1 == previous_state assert "Restored to" in message @@ -577,7 +577,7 @@ def test_flow_rollback(flow_instance): # Test rolling back to a previous state state1 = flow_instance.get_state() flow_instance.change_prompt("New prompt") - state2 = flow_instance.get_state() + flow_instance.get_state() flow_instance.rollback_to_state(state1) assert ( flow_instance.get_current_prompt() == state1["current_prompt"] diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py index 92d013b7..f3b9fefa 100644 --- a/tests/structs/test_autoscaler.py +++ b/tests/structs/test_autoscaler.py @@ -23,18 +23,18 @@ def test_autoscaler_init(): assert autoscaler.scale_up_factor == 1 assert autoscaler.idle_threshold == 0.2 assert autoscaler.busy_threshold == 0.7 - assert autoscaler.autoscale == True + assert autoscaler.autoscale is True assert autoscaler.min_agents == 1 assert autoscaler.max_agents == 5 - assert autoscaler.custom_scale_strategy == None + assert autoscaler.custom_scale_strategy is None assert len(autoscaler.agents_pool) == 5 - assert autoscaler.task_queue.empty() == True + assert autoscaler.task_queue.empty() is True def test_autoscaler_add_task(): autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler.add_task("task1") - assert autoscaler.task_queue.empty() == False + assert autoscaler.task_queue.empty() is False def test_autoscaler_run(): @@ -75,7 +75,7 @@ def test_autoscaler_get_agent_by_id(): def test_autoscaler_get_agent_by_id_not_found(): autoscaler = AutoScaler(initial_agents=5, agent=agent) agent = autoscaler.get_agent_by_id("fake_id") - assert agent == None + assert agent is None @patch("swarms.swarms.Agent.is_healthy") diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py index 2c116402..fada564a 100644 --- a/tests/structs/test_task.py +++ b/tests/structs/test_task.py @@ -1,4 +1,3 @@ -import os from unittest.mock import Mock import pytest @@ -163,4 +162,4 @@ def test_execute(): agent = Agent() task = Task(id="5", task="Task5", result=None, agents=[agent]) # Assuming execute method returns True on successful execution - assert task.execute() == True + assert task.execute() is True diff --git a/tests/swarms/test_multi_agent_collab.py b/tests/swarms/test_multi_agent_collab.py index e30358aa..f56170e8 100644 --- a/tests/swarms/test_multi_agent_collab.py +++ b/tests/swarms/test_multi_agent_collab.py @@ -26,7 +26,7 @@ def test_collaboration_initialization(collaboration): assert callable(collaboration.select_next_speaker) assert collaboration.max_iters == 10 assert collaboration.results == [] - assert collaboration.logging == True + assert collaboration.logging is True def test_reset(collaboration): diff --git a/tests/upload_tests_to_issues.py b/tests/upload_tests_to_issues.py index 864fee29..15de1245 100644 --- a/tests/upload_tests_to_issues.py +++ b/tests/upload_tests_to_issues.py @@ -1,7 +1,5 @@ import os import subprocess -import json -import re import requests from dotenv import load_dotenv diff --git a/tests/utils/test_phoenix_handler.py b/tests/utils/test_phoenix_handler.py index 3b6915b9..a7dc8898 100644 --- a/tests/utils/test_phoenix_handler.py +++ b/tests/utils/test_phoenix_handler.py @@ -1,8 +1,6 @@ # Import necessary modules and functions for testing -import functools import subprocess import sys -import traceback import pytest