From 36c295ed8ba722ce741a2164211b788f112a5255 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Fri, 1 Mar 2024 16:21:28 -0800
Subject: [PATCH] update LangChain imports & run code_quality check

---
 .gitignore | 2 +-
 swarms/agents/worker_agent.py | 6 +-
 swarms/memory/lanchain_chroma.py | 2 +-
 swarms/models/openai_models.py | 10 +-
 tests/agents/test_multion.py | 24 +-
 tests/agents/test_tool_agent.py | 81 +++--
 tests/memory/test_dictinternalmemory.py | 5 +-
 tests/memory/test_dictsharedmemory.py | 26 +-
 .../test_langchainchromavectormemory.py | 30 +-
 tests/memory/test_pinecone.py | 6 +-
 tests/memory/test_pq_db.py | 9 +-
 tests/memory/test_qdrant.py | 12 +-
 tests/memory/test_short_term_memory.py | 64 ++--
 tests/memory/test_sqlite.py | 16 +-
 tests/memory/test_weaviate.py | 103 +++---
 tests/models/test_anthropic.py | 65 ++--
 tests/models/test_biogpt.py | 33 +-
 tests/models/test_cohere.py | 285 +++++-----------
 tests/models/test_elevenlab.py | 46 +--
 tests/models/test_fire_function_caller.py | 4 +-
 tests/models/test_fuyu.py | 37 +-
 tests/models/test_gemini.py | 63 ++--
 tests/models/test_gigabind.py | 59 ++--
 tests/models/test_gpt4_vision_api.py | 118 +++----
 tests/models/test_hf.py | 8 +-
 tests/models/test_hf_pipeline.py | 16 +-
 tests/models/test_huggingface.py | 35 +-
 tests/models/test_idefics.py | 88 ++---
 tests/models/test_kosmos.py | 41 +--
 tests/models/test_llama_function_caller.py | 16 +-
 tests/models/test_mixtral.py | 16 +-
 tests/models/test_mpt7b.py | 16 +-
 tests/models/test_nougat.py | 44 +--
 tests/models/test_qwen.py | 40 +--
 tests/models/test_speech_t5.py | 46 +--
 tests/models/test_ssd_1b.py | 23 +-
 tests/models/test_timm.py | 10 +-
 tests/models/test_timm_model.py | 7 +-
 tests/models/test_togther.py | 19 +-
 tests/models/test_ultralytics.py | 4 +-
 tests/models/test_vilt.py | 12 +-
 tests/models/test_yi_200k.py | 45 +--
 tests/models/test_zeroscope.py | 43 +--
 tests/structs/test_agent.py | 321 ++++++------------
 tests/structs/test_autoscaler.py | 8 +-
 tests/structs/test_base.py | 59 ++--
 tests/structs/test_base_workflow.py | 9 +-
 tests/structs/test_concurrent_workflow.py | 4 +-
 tests/structs/test_conversation.py | 19 +-
 tests/structs/test_groupchat.py | 27 +-
 tests/structs/test_json.py | 7 +-
 tests/structs/test_majority_voting.py | 36 +-
 tests/structs/test_message_pool.py | 80 ++---
 tests/structs/test_model_parallizer.py | 59 ++--
 tests/structs/test_multi_agent_collab.py | 13 +-
 tests/structs/test_nonlinear_workflow.py | 5 +-
 tests/structs/test_recursive_workflow.py | 4 +-
 tests/structs/test_sequential_workflow.py | 9 +-
 tests/structs/test_swarmnetwork.py | 4 +-
 tests/structs/test_task.py | 39 +--
 tests/structs/test_taskqueuebase.py | 7 +-
 tests/structs/test_team.py | 12 +-
 tests/structs/test_tests_graph_workflow.py | 20 +-
 tests/telemetry/test_posthog_utils.py | 5 +-
 tests/telemetry/test_user_utils.py | 8 +-
 tests/test_upload_tests_to_issues.py | 8 +-
 tests/tokenizers/test_anthropictokenizer.py | 5 +-
 tests/tokenizers/test_basetokenizer.py | 8 +-
 tests/tokenizers/test_huggingfacetokenizer.py | 16 +-
 tests/tokenizers/test_openaitokenizer.py | 24 +-
 tests/tokenizers/test_tokenizer.py | 53 +--
 tests/tools/test_tools_base.py | 77 +++--
 tests/utils/test_check_device.py | 21 +-
 tests/utils/test_class_args_wrapper.py | 7 +-
 tests/utils/test_device.py | 46 +--
 tests/utils/test_display_markdown_message.py | 6 +-
 .../utils/test_extract_code_from_markdown.py | 15 +-
 tests/utils/test_find_image_path.py | 12 +-
 tests/utils/test_limit_tokens_from_string.py | 28 +-
 tests/utils/test_load_model_torch.py | 32 +-
 tests/utils/test_load_models_torch.py | 10 +-
 tests/utils/test_math_eval.py | 2 +
 tests/utils/test_metrics_decorator.py | 14 +-
 tests/utils/test_pdf_to_text.py | 4 +-
 tests/utils/test_prep_torch_inference.py | 9 +-
 .../utils/test_prep_torch_model_inference.py | 3 +-
 tests/utils/test_print_class_parameters.py | 24 +-
 .../utils/test_subprocess_code_interpreter.py | 23 +-
 tests/utils/test_try_except_wrapper.py | 22 +-
 89 files changed, 1074 insertions(+), 1785 deletions(-)

diff --git a/.gitignore b/.gitignore
index d9d1aa3b..e24110aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,7 +110,7 @@ docs/_build/
 # PyBuilder
 .pybuilder/
 target/
-
+`
 # Jupyter Notebook
 .ipynb_checkpoints
 
diff --git a/swarms/agents/worker_agent.py b/swarms/agents/worker_agent.py
index c0e7f464..6caf6088 100644
--- a/swarms/agents/worker_agent.py
+++ b/swarms/agents/worker_agent.py
@@ -2,9 +2,9 @@ import os
 from typing import List
 
 import faiss
-from langchain.docstore import InMemoryDocstore
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import FAISS
+from langchain_community.docstore import InMemoryDocstore
+from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.vectorstores import FAISS
 from langchain_experimental.autonomous_agents import AutoGPT
 
 from swarms.tools.tool import BaseTool
diff --git a/swarms/memory/lanchain_chroma.py b/swarms/memory/lanchain_chroma.py
index 95a2e9e3..cf846d31 100644
--- a/swarms/memory/lanchain_chroma.py
+++ b/swarms/memory/lanchain_chroma.py
@@ -5,7 +5,7 @@ from langchain.chains import RetrievalQA
 from langchain.chains.question_answering import load_qa_chain
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.text_splitter import CharacterTextSplitter
-from langchain.vectorstores import Chroma
+from langchain_community.vectorstores import Chroma
 
 from swarms.models.openai_models import OpenAIChat
 
diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py
index 635cd586..90f76734 100644
--- a/swarms/models/openai_models.py
+++ b/swarms/models/openai_models.py
@@ -647,7 +647,10 @@ class BaseOpenAI(BaseLLM):
         if self.openai_proxy:
             import openai
 
-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {
+                "http": self.openai_proxy,
+                "https": self.openai_proxy,
+            }  # type: ignore[assignment]  # noqa: E501
         return {**openai_creds, **self._default_params}
 
     @property
@@ -963,7 +966,10 @@ class OpenAIChat(BaseLLM):
             if openai_organization:
                 openai.organization = openai_organization
             if openai_proxy:
-                openai.proxy = {"http": openai_proxy, "https": openai_proxy}  # type: ignore[assignment]  # noqa: E501
+                openai.proxy = {
+                    "http": openai_proxy,
+                    "https": openai_proxy,
+                }  # type: ignore[assignment]  # noqa: E501
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. 
" diff --git a/tests/agents/test_multion.py b/tests/agents/test_multion.py index 23614934..66399676 100644 --- a/tests/agents/test_multion.py +++ b/tests/agents/test_multion.py @@ -23,19 +23,15 @@ def test_multion_agent_run(mock_multion): assert result == "result" assert status == "status" assert last_url == "lastUrl" - mock_multion.browse.assert_called_once_with( - { - "cmd": "task", - "url": "https://www.example.com", - "maxSteps": 5, - } - ) + mock_multion.browse.assert_called_once_with({ + "cmd": "task", + "url": "https://www.example.com", + "maxSteps": 5, + }) # Additional tests for different tasks -@pytest.mark.parametrize( - "task", ["task1", "task2", "task3", "task4", "task5"] -) +@pytest.mark.parametrize("task", ["task1", "task2", "task3", "task4", "task5"]) @patch("swarms.agents.multion_agent.multion") def test_multion_agent_run_different_tasks(mock_multion, task): mock_response = MagicMock() @@ -54,6 +50,8 @@ def test_multion_agent_run_different_tasks(mock_multion, task): assert result == "result" assert status == "status" assert last_url == "lastUrl" - mock_multion.browse.assert_called_once_with( - {"cmd": task, "url": "https://www.example.com", "maxSteps": 5} - ) + mock_multion.browse.assert_called_once_with({ + "cmd": task, + "url": "https://www.example.com", + "maxSteps": 5 + }) diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 691489c0..4ee71a22 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -11,18 +11,27 @@ def test_tool_agent_init(): json_schema = { "type": "object", "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - "is_student": {"type": "boolean"}, - "courses": {"type": "array", "items": {"type": "string"}}, + "name": { + "type": "string" + }, + "age": { + "type": "number" + }, + "is_student": { + "type": "boolean" + }, + "courses": { + "type": "array", + "items": { + "type": "string" + } + }, }, } name = "Test Agent" description = "This is a test agent" - agent = ToolAgent( - name, description, model, tokenizer, json_schema - ) + agent = ToolAgent(name, description, model, tokenizer, json_schema) assert agent.name == name assert agent.description == description @@ -38,22 +47,29 @@ def test_tool_agent_run(mock_run): json_schema = { "type": "object", "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - "is_student": {"type": "boolean"}, - "courses": {"type": "array", "items": {"type": "string"}}, + "name": { + "type": "string" + }, + "age": { + "type": "number" + }, + "is_student": { + "type": "boolean" + }, + "courses": { + "type": "array", + "items": { + "type": "string" + } + }, }, } name = "Test Agent" description = "This is a test agent" - task = ( - "Generate a person's information based on the following" - " schema:" - ) + task = ("Generate a person's information based on the following" + " schema:") - agent = ToolAgent( - name, description, model, tokenizer, json_schema - ) + agent = ToolAgent(name, description, model, tokenizer, json_schema) agent.run(task) mock_run.assert_called_once_with(task) @@ -65,10 +81,21 @@ def test_tool_agent_init_with_kwargs(): json_schema = { "type": "object", "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - "is_student": {"type": "boolean"}, - "courses": {"type": "array", "items": {"type": "string"}}, + "name": { + "type": "string" + }, + "age": { + "type": "number" + }, + "is_student": { + "type": "boolean" + }, + "courses": { + "type": "array", + "items": { + "type": "string" + } + 
}, }, } name = "Test Agent" @@ -82,9 +109,8 @@ def test_tool_agent_init_with_kwargs(): "max_string_token_length": 20, } - agent = ToolAgent( - name, description, model, tokenizer, json_schema, **kwargs - ) + agent = ToolAgent(name, description, model, tokenizer, json_schema, + **kwargs) assert agent.name == name assert agent.description == description @@ -95,7 +121,4 @@ def test_tool_agent_init_with_kwargs(): assert agent.max_array_length == kwargs["max_array_length"] assert agent.max_number_tokens == kwargs["max_number_tokens"] assert agent.temperature == kwargs["temperature"] - assert ( - agent.max_string_token_length - == kwargs["max_string_token_length"] - ) + assert (agent.max_string_token_length == kwargs["max_string_token_length"]) diff --git a/tests/memory/test_dictinternalmemory.py b/tests/memory/test_dictinternalmemory.py index 7658eb7c..409acbf2 100644 --- a/tests/memory/test_dictinternalmemory.py +++ b/tests/memory/test_dictinternalmemory.py @@ -33,9 +33,8 @@ def test_memory_limit_enforced(memory): # Parameterized Tests -@pytest.mark.parametrize( - "scores, best_score", [([10, 5, 3], 10), ([1, 2, 3], 3)] -) +@pytest.mark.parametrize("scores, best_score", [([10, 5, 3], 10), + ([1, 2, 3], 3)]) def test_get_top_n(scores, best_score, memory): for score in scores: memory.add(score, {"data": f"test{score}"}) diff --git a/tests/memory/test_dictsharedmemory.py b/tests/memory/test_dictsharedmemory.py index a41ccd8f..f391cdc2 100644 --- a/tests/memory/test_dictsharedmemory.py +++ b/tests/memory/test_dictsharedmemory.py @@ -26,8 +26,7 @@ def memory_instance(memory_file): def test_init(memory_file): memory = DictSharedMemory(file_loc=memory_file) assert os.path.exists( - memory.file_loc - ), "Memory file should be created if non-existent" + memory.file_loc), "Memory file should be created if non-existent" def test_add_entry(memory_instance): @@ -44,9 +43,8 @@ def test_get_top_n(memory_instance): memory_instance.add(9.5, "agent123", 1, "Entry A") memory_instance.add(8.5, "agent124", 1, "Entry B") top_1 = memory_instance.get_top_n(1) - assert ( - len(top_1) == 1 - ), "get_top_n should return the correct number of top entries" + assert (len(top_1) == 1 + ), "get_top_n should return the correct number of top entries" # Parameterized tests @@ -59,18 +57,14 @@ def test_get_top_n(memory_instance): # add more test cases ], ) -def test_parametrized_get_top_n( - memory_instance, scores, agent_ids, expected_top_score -): +def test_parametrized_get_top_n(memory_instance, scores, agent_ids, + expected_top_score): for score, agent_id in zip(scores, agent_ids): - memory_instance.add( - score, agent_id, 1, f"Entry by {agent_id}" - ) + memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}") top_1 = memory_instance.get_top_n(1) top_score = next(iter(top_1.values()))["score"] - assert ( - top_score == expected_top_score - ), "get_top_n should return the entry with top score" + assert (top_score == expected_top_score + ), "get_top_n should return the entry with top score" # Exception testing @@ -78,9 +72,7 @@ def test_parametrized_get_top_n( def test_add_entry_invalid_input(memory_instance): with pytest.raises(ValueError): - memory_instance.add( - "invalid_score", "agent123", 1, "Test Entry" - ) + memory_instance.add("invalid_score", "agent123", 1, "Test Entry") # Mocks and monkey-patching diff --git a/tests/memory/test_langchainchromavectormemory.py b/tests/memory/test_langchainchromavectormemory.py index ee882c6c..25316ae2 100644 --- a/tests/memory/test_langchainchromavectormemory.py +++ 
b/tests/memory/test_langchainchromavectormemory.py @@ -35,16 +35,13 @@ def qa_mock(): # Example test cases def test_initialization_default_settings(vector_memory): assert vector_memory.chunk_size == 1000 - assert ( - vector_memory.chunk_overlap == 100 - ) # assuming default overlap of 0.1 + assert (vector_memory.chunk_overlap == 100 + ) # assuming default overlap of 0.1 assert vector_memory.loc.exists() def test_add_entry(vector_memory, embeddings_mock): - with patch.object( - vector_memory.db, "add_texts" - ) as add_texts_mock: + with patch.object(vector_memory.db, "add_texts") as add_texts_mock: vector_memory.add("Example text") add_texts_mock.assert_called() @@ -77,20 +74,17 @@ def test_ask_question_returns_string(vector_memory, qa_mock): ), # Mocked object as a placeholder ], ) -def test_search_memory_different_params( - vector_memory, query, k, type, expected -): +def test_search_memory_different_params(vector_memory, query, k, type, + expected): with patch.object( - vector_memory.db, - "max_marginal_relevance_search", - return_value=expected, - ): - with patch.object( vector_memory.db, - "similarity_search_with_score", + "max_marginal_relevance_search", return_value=expected, + ): + with patch.object( + vector_memory.db, + "similarity_search_with_score", + return_value=expected, ): - result = vector_memory.search_memory( - query, k=k, type=type - ) + result = vector_memory.search_memory(query, k=k, type=type) assert len(result) == (k if k > 0 else 0) diff --git a/tests/memory/test_pinecone.py b/tests/memory/test_pinecone.py index a7d4fcea..bedce998 100644 --- a/tests/memory/test_pinecone.py +++ b/tests/memory/test_pinecone.py @@ -8,8 +8,7 @@ api_key = os.getenv("PINECONE_API_KEY") or "" def test_init(): with patch("pinecone.init") as MockInit, patch( - "pinecone.Index" - ) as MockIndex: + "pinecone.Index") as MockIndex: store = PineconeDB( api_key=api_key, index_name="test_index", @@ -71,8 +70,7 @@ def test_query(): def test_create_index(): with patch("pinecone.init"), patch("pinecone.Index"), patch( - "pinecone.create_index" - ) as MockCreateIndex: + "pinecone.create_index") as MockCreateIndex: store = PineconeDB( api_key=api_key, index_name="test_index", diff --git a/tests/memory/test_pq_db.py b/tests/memory/test_pq_db.py index 5e44f0ba..e3e0925c 100644 --- a/tests/memory/test_pq_db.py +++ b/tests/memory/test_pq_db.py @@ -32,8 +32,7 @@ def test_create_vector_model(): def test_add_or_update_vector(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session" - ) as MockSession: + "sqlalchemy.orm.Session") as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", @@ -51,8 +50,7 @@ def test_add_or_update_vector(): def test_query_vectors(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session" - ) as MockSession: + "sqlalchemy.orm.Session") as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", @@ -67,8 +65,7 @@ def test_query_vectors(): def test_delete_vector(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session" - ) as MockSession: + "sqlalchemy.orm.Session") as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", diff --git a/tests/memory/test_qdrant.py b/tests/memory/test_qdrant.py index 5f82814c..f6023dd6 100644 --- a/tests/memory/test_qdrant.py +++ b/tests/memory/test_qdrant.py @@ -13,9 +13,8 @@ def mock_qdrant_client(): @pytest.fixture def mock_sentence_transformer(): - with patch( - 
"sentence_transformers.SentenceTransformer" - ) as MockSentenceTransformer: + with patch("sentence_transformers.SentenceTransformer" + ) as MockSentenceTransformer: yield MockSentenceTransformer() @@ -29,9 +28,7 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client): assert qdrant_client.client is not None -def test_load_embedding_model( - qdrant_client, mock_sentence_transformer -): +def test_load_embedding_model(qdrant_client, mock_sentence_transformer): qdrant_client._load_embedding_model("model_name") mock_sentence_transformer.assert_called_once_with("model_name") @@ -39,8 +36,7 @@ def test_load_embedding_model( def test_setup_collection(qdrant_client, mock_qdrant_client): qdrant_client._setup_collection() mock_qdrant_client.get_collection.assert_called_once_with( - qdrant_client.collection_name - ) + qdrant_client.collection_name) def test_add_vectors(qdrant_client, mock_qdrant_client): diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 132da5f6..2dbd9fc9 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -12,26 +12,29 @@ def test_init(): def test_add(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert memory.short_term_memory == [ - {"role": "user", "message": "Hello, world!"} - ] + assert memory.short_term_memory == [{ + "role": "user", + "message": "Hello, world!" + }] def test_get_short_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert memory.get_short_term() == [ - {"role": "user", "message": "Hello, world!"} - ] + assert memory.get_short_term() == [{ + "role": "user", + "message": "Hello, world!" + }] def test_get_medium_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert memory.get_medium_term() == [ - {"role": "user", "message": "Hello, world!"} - ] + assert memory.get_medium_term() == [{ + "role": "user", + "message": "Hello, world!" + }] def test_clear_medium_term(): @@ -45,19 +48,18 @@ def test_clear_medium_term(): def test_get_short_term_memory_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert ( - memory.get_short_term_memory_str() - == "[{'role': 'user', 'message': 'Hello, world!'}]" - ) + assert (memory.get_short_term_memory_str() == + "[{'role': 'user', 'message': 'Hello, world!'}]") def test_update_short_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.update_short_term(0, "user", "Goodbye, world!") - assert memory.get_short_term() == [ - {"role": "user", "message": "Goodbye, world!"} - ] + assert memory.get_short_term() == [{ + "role": "user", + "message": "Goodbye, world!" + }] def test_clear(): @@ -71,9 +73,10 @@ def test_search_memory(): memory = ShortTermMemory() memory.add("user", "Hello, world!") assert memory.search_memory("Hello") == { - "short_term": [ - (0, {"role": "user", "message": "Hello, world!"}) - ], + "short_term": [(0, { + "role": "user", + "message": "Hello, world!" 
+ })], "medium_term": [], } @@ -81,19 +84,18 @@ def test_search_memory(): def test_return_shortmemory_as_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert ( - memory.return_shortmemory_as_str() - == "[{'role': 'user', 'message': 'Hello, world!'}]" - ) + assert (memory.return_shortmemory_as_str() == + "[{'role': 'user', 'message': 'Hello, world!'}]") def test_move_to_medium_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert memory.get_medium_term() == [ - {"role": "user", "message": "Hello, world!"} - ] + assert memory.get_medium_term() == [{ + "role": "user", + "message": "Hello, world!" + }] assert memory.get_short_term() == [] @@ -101,10 +103,8 @@ def test_return_medium_memory_as_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert ( - memory.return_medium_memory_as_str() - == "[{'role': 'user', 'message': 'Hello, world!'}]" - ) + assert (memory.return_medium_memory_as_str() == + "[{'role': 'user', 'message': 'Hello, world!'}]") def test_thread_safety(): @@ -114,9 +114,7 @@ def test_thread_safety(): for _ in range(1000): memory.add("user", "Hello, world!") - threads = [ - threading.Thread(target=add_messages) for _ in range(10) - ] + threads = [threading.Thread(target=add_messages) for _ in range(10)] for thread in threads: thread.start() for thread in threads: diff --git a/tests/memory/test_sqlite.py b/tests/memory/test_sqlite.py index 49d61ef7..50808817 100644 --- a/tests/memory/test_sqlite.py +++ b/tests/memory/test_sqlite.py @@ -8,9 +8,7 @@ from swarms.memory.sqlite import SQLiteDB @pytest.fixture def db(): conn = sqlite3.connect(":memory:") - conn.execute( - "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" - ) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") conn.commit() return SQLiteDB(":memory:") @@ -30,9 +28,7 @@ def test_delete(db): def test_update(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - db.update( - "UPDATE test SET name = ? WHERE name = ?", ("new", "test") - ) + db.update("UPDATE test SET name = ? 
WHERE name = ?", ("new", "test")) result = db.query("SELECT * FROM test") assert result == [(1, "new")] @@ -45,9 +41,7 @@ def test_query(db): def test_execute_query(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - result = db.execute_query( - "SELECT * FROM test WHERE name = ?", ("test",) - ) + result = db.execute_query("SELECT * FROM test WHERE name = ?", ("test",)) assert result == [(1, "test")] @@ -101,6 +95,4 @@ def test_query_with_wrong_query(db): def test_execute_query_with_wrong_query(db): with pytest.raises(sqlite3.OperationalError): - db.execute_query( - "SELECT * FROM wrong WHERE name = ?", ("test",) - ) + db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",)) diff --git a/tests/memory/test_weaviate.py b/tests/memory/test_weaviate.py index d1a69da0..f05d1b50 100644 --- a/tests/memory/test_weaviate.py +++ b/tests/memory/test_weaviate.py @@ -16,9 +16,7 @@ def weaviate_client_mock(): grpc_port="mock_grpc_port", grpc_secure=False, auth_client_secret="mock_api_key", - additional_headers={ - "X-OpenAI-Api-Key": "mock_openai_api_key" - }, + additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"}, additional_config=Mock(), ) @@ -36,13 +34,15 @@ def weaviate_client_mock(): # Define tests for the WeaviateDB class def test_create_collection(weaviate_client_mock): # Test creating a collection - weaviate_client_mock.create_collection( - "test_collection", [{"name": "property"}] - ) + weaviate_client_mock.create_collection("test_collection", [{ + "name": "property" + }]) weaviate_client_mock.client.collections.create.assert_called_with( name="test_collection", vectorizer_config=None, - properties=[{"name": "property"}], + properties=[{ + "name": "property" + }], ) @@ -51,11 +51,9 @@ def test_add_object(weaviate_client_mock): properties = {"name": "John"} weaviate_client_mock.add("test_collection", properties) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.data.insert.assert_called_with( - properties - ) + properties) def test_query_objects(weaviate_client_mock): @@ -63,26 +61,20 @@ def test_query_objects(weaviate_client_mock): query = "name:John" weaviate_client_mock.query("test_collection", query) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=10 - ) + query=query, limit=10) def test_update_object(weaviate_client_mock): # Test updating an object object_id = "12345" properties = {"name": "Jane"} - weaviate_client_mock.update( - "test_collection", object_id, properties - ) + weaviate_client_mock.update("test_collection", object_id, properties) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.data.update.assert_called_with( - object_id, properties - ) + object_id, properties) def test_delete_object(weaviate_client_mock): @@ -90,25 +82,23 @@ def test_delete_object(weaviate_client_mock): object_id = "12345" weaviate_client_mock.delete("test_collection", object_id) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with( - object_id - ) + object_id) -def test_create_collection_with_vectorizer_config( - weaviate_client_mock, -): +def 
test_create_collection_with_vectorizer_config(weaviate_client_mock,): # Test creating a collection with vectorizer configuration vectorizer_config = {"config_key": "config_value"} - weaviate_client_mock.create_collection( - "test_collection", [{"name": "property"}], vectorizer_config - ) + weaviate_client_mock.create_collection("test_collection", [{ + "name": "property" + }], vectorizer_config) weaviate_client_mock.client.collections.create.assert_called_with( name="test_collection", vectorizer_config=vectorizer_config, - properties=[{"name": "property"}], + properties=[{ + "name": "property" + }], ) @@ -118,11 +108,9 @@ def test_query_objects_with_limit(weaviate_client_mock): limit = 20 weaviate_client_mock.query("test_collection", query, limit) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=limit - ) + query=query, limit=limit) def test_query_objects_without_limit(weaviate_client_mock): @@ -130,33 +118,29 @@ def test_query_objects_without_limit(weaviate_client_mock): query = "name:John" weaviate_client_mock.query("test_collection", query) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection" - ) + "test_collection") weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=10 - ) + query=query, limit=10) def test_create_collection_failure(weaviate_client_mock): # Test failure when creating a collection with patch( - "weaviate_client.weaviate.collections.create", - side_effect=Exception("Create error"), + "weaviate_client.weaviate.collections.create", + side_effect=Exception("Create error"), ): - with pytest.raises( - Exception, match="Error creating collection" - ): - weaviate_client_mock.create_collection( - "test_collection", [{"name": "property"}] - ) + with pytest.raises(Exception, match="Error creating collection"): + weaviate_client_mock.create_collection("test_collection", [{ + "name": "property" + }]) def test_add_object_failure(weaviate_client_mock): # Test failure when adding an object properties = {"name": "John"} with patch( - "weaviate_client.weaviate.collections.data.insert", - side_effect=Exception("Insert error"), + "weaviate_client.weaviate.collections.data.insert", + side_effect=Exception("Insert error"), ): with pytest.raises(Exception, match="Error adding object"): weaviate_client_mock.add("test_collection", properties) @@ -166,8 +150,8 @@ def test_query_objects_failure(weaviate_client_mock): # Test failure when querying objects query = "name:John" with patch( - "weaviate_client.weaviate.collections.query.bm25", - side_effect=Exception("Query error"), + "weaviate_client.weaviate.collections.query.bm25", + side_effect=Exception("Query error"), ): with pytest.raises(Exception, match="Error querying objects"): weaviate_client_mock.query("test_collection", query) @@ -178,21 +162,20 @@ def test_update_object_failure(weaviate_client_mock): object_id = "12345" properties = {"name": "Jane"} with patch( - "weaviate_client.weaviate.collections.data.update", - side_effect=Exception("Update error"), + "weaviate_client.weaviate.collections.data.update", + side_effect=Exception("Update error"), ): with pytest.raises(Exception, match="Error updating object"): - weaviate_client_mock.update( - "test_collection", object_id, properties - ) + weaviate_client_mock.update("test_collection", object_id, + properties) def test_delete_object_failure(weaviate_client_mock): # Test 
failure when deleting an object object_id = "12345" with patch( - "weaviate_client.weaviate.collections.data.delete_by_id", - side_effect=Exception("Delete error"), + "weaviate_client.weaviate.collections.data.delete_by_id", + side_effect=Exception("Delete error"), ): with pytest.raises(Exception, match="Error deleting object"): weaviate_client_mock.delete("test_collection", object_id) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index cc48479a..3ec19e8c 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -8,12 +8,11 @@ from swarms.models.anthropic import Anthropic # Mock the Anthropic API client for testing class MockAnthropicClient: + def __init__(self, *args, **kwargs): pass - def completions_create( - self, prompt, stop_sequences, stream, **kwargs - ): + def completions_create(self, prompt, stop_sequences, stream, **kwargs): return MockAnthropicResponse() @@ -46,9 +45,7 @@ def test_anthropic_init_default_values(anthropic_instance): assert anthropic_instance.streaming is False assert anthropic_instance.default_request_timeout == 600 assert ( - anthropic_instance.anthropic_api_url - == "https://test.anthropic.com" - ) + anthropic_instance.anthropic_api_url == "https://test.anthropic.com") assert anthropic_instance.anthropic_api_key == "test_api_key" @@ -79,9 +76,8 @@ def test_anthropic_default_params(anthropic_instance): } -def test_anthropic_run( - mock_anthropic_env, mock_requests_post, anthropic_instance -): +def test_anthropic_run(mock_anthropic_env, mock_requests_post, + anthropic_instance): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -105,9 +101,8 @@ def test_anthropic_run( ) -def test_anthropic_call( - mock_anthropic_env, mock_requests_post, anthropic_instance -): +def test_anthropic_call(mock_anthropic_env, mock_requests_post, + anthropic_instance): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -131,9 +126,8 @@ def test_anthropic_call( ) -def test_anthropic_exception_handling( - mock_anthropic_env, mock_requests_post, anthropic_instance -): +def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post, + anthropic_instance): mock_response = Mock() mock_response.json.return_value = {"error": "An error occurred"} mock_requests_post.return_value = mock_response @@ -148,6 +142,7 @@ def test_anthropic_exception_handling( class MockAnthropicResponse: + def __init__(self): self.completion = "Mocked Response from Anthropic" @@ -173,9 +168,7 @@ def test_anthropic_async_call_method(anthropic_instance): def test_anthropic_async_stream_method(anthropic_instance): - async_generator = anthropic_instance.async_stream( - "Translate to French." - ) + async_generator = anthropic_instance.async_stream("Translate to French.") for token in async_generator: assert isinstance(token, str) @@ -199,63 +192,51 @@ def test_anthropic_wrap_prompt(anthropic_instance): def test_anthropic_convert_prompt(anthropic_instance): prompt = "What is the meaning of life?" 
converted_prompt = anthropic_instance.convert_prompt(prompt) - assert converted_prompt.startswith( - anthropic_instance.HUMAN_PROMPT - ) + assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) def test_anthropic_call_with_stop(anthropic_instance): - response = anthropic_instance( - "Translate to French.", stop=["stop1", "stop2"] - ) + response = anthropic_instance("Translate to French.", + stop=["stop1", "stop2"]) assert response == "Mocked Response from Anthropic" def test_anthropic_stream_with_stop(anthropic_instance): - generator = anthropic_instance.stream( - "Write a story.", stop=["stop1", "stop2"] - ) + generator = anthropic_instance.stream("Write a story.", + stop=["stop1", "stop2"]) for token in generator: assert isinstance(token, str) def test_anthropic_async_call_with_stop(anthropic_instance): - response = anthropic_instance.async_call( - "Tell me a joke.", stop=["stop1", "stop2"] - ) + response = anthropic_instance.async_call("Tell me a joke.", + stop=["stop1", "stop2"]) assert response == "Mocked Response from Anthropic" def test_anthropic_async_stream_with_stop(anthropic_instance): - async_generator = anthropic_instance.async_stream( - "Translate to French.", stop=["stop1", "stop2"] - ) + async_generator = anthropic_instance.async_stream("Translate to French.", + stop=["stop1", "stop2"]) for token in async_generator: assert isinstance(token, str) -def test_anthropic_get_num_tokens_with_count_tokens( - anthropic_instance, -): +def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance,): anthropic_instance.count_tokens = Mock(return_value=10) text = "This is a test sentence." num_tokens = anthropic_instance.get_num_tokens(text) assert num_tokens == 10 -def test_anthropic_get_num_tokens_without_count_tokens( - anthropic_instance, -): +def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance,): del anthropic_instance.count_tokens with pytest.raises(NameError): text = "This is a test sentence." anthropic_instance.get_num_tokens(text) -def test_anthropic_wrap_prompt_without_human_ai_prompt( - anthropic_instance, -): +def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance,): del anthropic_instance.HUMAN_PROMPT del anthropic_instance.AI_PROMPT prompt = "What is the meaning of life?" diff --git a/tests/models/test_biogpt.py b/tests/models/test_biogpt.py index e6093729..8f27e62a 100644 --- a/tests/models/test_biogpt.py +++ b/tests/models/test_biogpt.py @@ -48,10 +48,8 @@ def test_cell_biology_response(biogpt_instance): # 40. Test for a question about protein structure def test_protein_structure_response(biogpt_instance): - question = ( - "What's the difference between alpha helix and beta sheet" - " structures in proteins?" - ) + question = ("What's the difference between alpha helix and beta sheet" + " structures in proteins?") response = biogpt_instance(question) assert response assert isinstance(response, str) @@ -83,9 +81,7 @@ def test_bioinformatics_response(biogpt_instance): # 44. Test for a neuroscience question def test_neuroscience_response(biogpt_instance): - question = ( - "Explain the function of synapses in the nervous system." 
- ) + question = ("Explain the function of synapses in the nervous system.") response = biogpt_instance(question) assert response assert isinstance(response, str) @@ -108,8 +104,11 @@ def test_init(bio_gpt): def test_call(bio_gpt, monkeypatch): + def mock_pipeline(*args, **kwargs): + class MockGenerator: + def __call__(self, text, **kwargs): return ["Generated text"] @@ -167,9 +166,7 @@ def test_get_config_return_type(biogpt_instance): # 28. Test saving model functionality by checking if files are created @patch.object(BioGptForCausalLM, "save_pretrained") @patch.object(BioGptTokenizer, "save_pretrained") -def test_save_model( - mock_save_model, mock_save_tokenizer, biogpt_instance -): +def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): path = "test_path" biogpt_instance.save_model(path) mock_save_model.assert_called_once_with(path) @@ -179,9 +176,7 @@ def test_save_model( # 29. Test loading model from path @patch.object(BioGptForCausalLM, "from_pretrained") @patch.object(BioGptTokenizer, "from_pretrained") -def test_load_from_path( - mock_load_model, mock_load_tokenizer, biogpt_instance -): +def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance): path = "test_path" biogpt_instance.load_from_path(path) mock_load_model.assert_called_once_with(path) @@ -198,9 +193,7 @@ def test_print_model_metadata(biogpt_instance): # 31. Test that beam_search_decoding uses the correct number of beams @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_num_beams( - mock_generate, biogpt_instance -): +def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): biogpt_instance.beam_search_decoding("test_sentence", num_beams=7) _, kwargs = mock_generate.call_args assert kwargs["num_beams"] == 7 @@ -208,12 +201,8 @@ def test_beam_search_decoding_num_beams( # 32. Test if beam_search_decoding handles early_stopping @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_early_stopping( - mock_generate, biogpt_instance -): - biogpt_instance.beam_search_decoding( - "test_sentence", early_stopping=False - ) +def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance): + biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False) _, kwargs = mock_generate.call_args assert kwargs["early_stopping"] is False diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 8a1147d3..347cb9fc 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -42,9 +42,7 @@ def test_cohere_async_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - cohere_instance.async_call( - "Error handling with invalid API key." - ) + cohere_instance.async_call("Error handling with invalid API key.") def test_cohere_stream_api_error_handling(cohere_instance): @@ -53,8 +51,7 @@ def test_cohere_stream_api_error_handling(cohere_instance): cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): generator = cohere_instance.stream( - "Error handling with invalid API key." 
- ) + "Error handling with invalid API key.") for token in generator: pass @@ -94,31 +91,26 @@ def test_cohere_convert_prompt(cohere_instance): def test_cohere_call_with_stop(cohere_instance): - response = cohere_instance( - "Translate to French.", stop=["stop1", "stop2"] - ) + response = cohere_instance("Translate to French.", stop=["stop1", "stop2"]) assert response == "Mocked Response from Cohere" def test_cohere_stream_with_stop(cohere_instance): - generator = cohere_instance.stream( - "Write a story.", stop=["stop1", "stop2"] - ) + generator = cohere_instance.stream("Write a story.", + stop=["stop1", "stop2"]) for token in generator: assert isinstance(token, str) def test_cohere_async_call_with_stop(cohere_instance): - response = cohere_instance.async_call( - "Tell me a joke.", stop=["stop1", "stop2"] - ) + response = cohere_instance.async_call("Tell me a joke.", + stop=["stop1", "stop2"]) assert response == "Mocked Response from Cohere" def test_cohere_async_stream_with_stop(cohere_instance): - async_generator = cohere_instance.async_stream( - "Translate to French.", stop=["stop1", "stop2"] - ) + async_generator = cohere_instance.async_stream("Translate to French.", + stop=["stop1", "stop2"]) for token in async_generator: assert isinstance(token, str) @@ -174,12 +166,8 @@ def test_base_cohere_validate_environment_without_cohere(): # Test cases for benchmarking generations with various models def test_cohere_generate_with_command_light(cohere_instance): cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text with Command Light model." - ) - assert response.startswith( - "Generated text with Command Light model" - ) + response = cohere_instance("Generate text with Command Light model.") + assert response.startswith("Generated text with Command Light model") def test_cohere_generate_with_command(cohere_instance): @@ -202,74 +190,54 @@ def test_cohere_generate_with_base(cohere_instance): def test_cohere_generate_with_embed_english_v2(cohere_instance): cohere_instance.model = "embed-english-v2.0" - response = cohere_instance( - "Generate embeddings with English v2.0 model." - ) - assert response.startswith( - "Generated embeddings with English v2.0 model" - ) + response = cohere_instance("Generate embeddings with English v2.0 model.") + assert response.startswith("Generated embeddings with English v2.0 model") def test_cohere_generate_with_embed_english_light_v2(cohere_instance): cohere_instance.model = "embed-english-light-v2.0" response = cohere_instance( - "Generate embeddings with English Light v2.0 model." - ) + "Generate embeddings with English Light v2.0 model.") assert response.startswith( - "Generated embeddings with English Light v2.0 model" - ) + "Generated embeddings with English Light v2.0 model") def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance( - "Generate embeddings with Multilingual v2.0 model." - ) + "Generate embeddings with Multilingual v2.0 model.") assert response.startswith( - "Generated embeddings with Multilingual v2.0 model" - ) + "Generated embeddings with Multilingual v2.0 model") def test_cohere_generate_with_embed_english_v3(cohere_instance): cohere_instance.model = "embed-english-v3.0" - response = cohere_instance( - "Generate embeddings with English v3.0 model." 
- ) - assert response.startswith( - "Generated embeddings with English v3.0 model" - ) + response = cohere_instance("Generate embeddings with English v3.0 model.") + assert response.startswith("Generated embeddings with English v3.0 model") def test_cohere_generate_with_embed_english_light_v3(cohere_instance): cohere_instance.model = "embed-english-light-v3.0" response = cohere_instance( - "Generate embeddings with English Light v3.0 model." - ) + "Generate embeddings with English Light v3.0 model.") assert response.startswith( - "Generated embeddings with English Light v3.0 model" - ) + "Generated embeddings with English Light v3.0 model") def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance( - "Generate embeddings with Multilingual v3.0 model." - ) + "Generate embeddings with Multilingual v3.0 model.") assert response.startswith( - "Generated embeddings with Multilingual v3.0 model" - ) + "Generated embeddings with Multilingual v3.0 model") -def test_cohere_generate_with_embed_multilingual_light_v3( - cohere_instance, -): +def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance,): cohere_instance.model = "embed-multilingual-light-v3.0" response = cohere_instance( - "Generate embeddings with Multilingual Light v3.0 model." - ) + "Generate embeddings with Multilingual Light v3.0 model.") assert response.startswith( - "Generated embeddings with Multilingual Light v3.0 model" - ) + "Generated embeddings with Multilingual Light v3.0 model") # Add more test cases to benchmark other models and functionalities @@ -299,17 +267,13 @@ def test_cohere_call_with_embed_english_v3_model(cohere_instance): assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v2_model( - cohere_instance, -): +def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v3_model( - cohere_instance, -): +def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) @@ -329,9 +293,7 @@ def test_cohere_call_with_long_prompt(cohere_instance): def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." 
- ) + prompt = ("This is a test prompt that will exceed the max tokens limit.") with pytest.raises(ValueError): cohere_instance(prompt) @@ -364,18 +326,14 @@ def test_cohere_stream_with_embed_english_v3_model(cohere_instance): assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v2_model( - cohere_instance, -): +def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v2.0" generator = cohere_instance.stream("Write a story.") for token in generator: assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v3_model( - cohere_instance, -): +def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v3.0" generator = cohere_instance.stream("Write a story.") for token in generator: @@ -394,33 +352,25 @@ def test_cohere_async_call_with_base_model(cohere_instance): assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v2_model( - cohere_instance, -): +def test_cohere_async_call_with_embed_english_v2_model(cohere_instance,): cohere_instance.model = "embed-english-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v3_model( - cohere_instance, -): +def test_cohere_async_call_with_embed_english_v3_model(cohere_instance,): cohere_instance.model = "embed-english-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v2_model( - cohere_instance, -): +def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v3_model( - cohere_instance, -): +def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) @@ -440,36 +390,28 @@ def test_cohere_async_stream_with_base_model(cohere_instance): assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v2_model( - cohere_instance, -): +def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance,): cohere_instance.model = "embed-english-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v3_model( - cohere_instance, -): +def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance,): cohere_instance.model = "embed-english-v3.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v2_model( - cohere_instance, -): +def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v3_model( - cohere_instance, -): +def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,): cohere_instance.model = "embed-multilingual-v3.0" async_generator = 
cohere_instance.async_stream("Write a story.") for token in async_generator: @@ -479,9 +421,7 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model( def test_cohere_representation_model_embedding(cohere_instance): # Test using the Representation model for text embedding cohere_instance.model = "embed-english-v3.0" - embedding = cohere_instance.embed( - "Generate an embedding for this text." - ) + embedding = cohere_instance.embed("Generate an embedding for this text.") assert isinstance(embedding, list) assert len(embedding) > 0 @@ -495,26 +435,20 @@ def test_cohere_representation_model_classification(cohere_instance): assert "score" in classification -def test_cohere_representation_model_language_detection( - cohere_instance, -): +def test_cohere_representation_model_language_detection(cohere_instance,): # Test using the Representation model for language detection cohere_instance.model = "embed-english-v3.0" language = cohere_instance.detect_language( - "Detect the language of this text." - ) + "Detect the language of this text.") assert isinstance(language, str) def test_cohere_representation_model_max_tokens_limit_exceeded( - cohere_instance, -): + cohere_instance,): # Test handling max tokens limit exceeded error cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." - ) + prompt = ("This is a test prompt that will exceed the max tokens limit.") with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -522,102 +456,80 @@ def test_cohere_representation_model_max_tokens_limit_exceeded( # Add more production-grade test cases based on real-world scenarios -def test_cohere_representation_model_multilingual_embedding( - cohere_instance, -): +def test_cohere_representation_model_multilingual_embedding(cohere_instance,): # Test using the Representation model for multilingual text embedding cohere_instance.model = "embed-multilingual-v3.0" - embedding = cohere_instance.embed( - "Generate multilingual embeddings." - ) + embedding = cohere_instance.embed("Generate multilingual embeddings.") assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_multilingual_classification( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for multilingual text classification cohere_instance.model = "embed-multilingual-v3.0" - classification = cohere_instance.classify( - "Classify multilingual text." - ) + classification = cohere_instance.classify("Classify multilingual text.") assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_multilingual_language_detection( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for multilingual language detection cohere_instance.model = "embed-multilingual-v3.0" language = cohere_instance.detect_language( - "Detect the language of multilingual text." - ) + "Detect the language of multilingual text.") assert isinstance(language, str) def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( - cohere_instance, -): + cohere_instance,): # Test handling max tokens limit exceeded error for multilingual model cohere_instance.model = "embed-multilingual-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for multilingual model." 
- ) + prompt = ("This is a test prompt that will exceed the max tokens limit" + " for multilingual model.") with pytest.raises(ValueError): cohere_instance.embed(prompt) def test_cohere_representation_model_multilingual_light_embedding( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for multilingual light text embedding cohere_instance.model = "embed-multilingual-light-v3.0" - embedding = cohere_instance.embed( - "Generate multilingual light embeddings." - ) + embedding = cohere_instance.embed("Generate multilingual light embeddings.") assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_multilingual_light_classification( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for multilingual light text classification cohere_instance.model = "embed-multilingual-light-v3.0" classification = cohere_instance.classify( - "Classify multilingual light text." - ) + "Classify multilingual light text.") assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_multilingual_light_language_detection( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for multilingual light language detection cohere_instance.model = "embed-multilingual-light-v3.0" language = cohere_instance.detect_language( - "Detect the language of multilingual light text." - ) + "Detect the language of multilingual light text.") assert isinstance(language, str) def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded( - cohere_instance, -): + cohere_instance,): # Test handling max tokens limit exceeded error for multilingual light model cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for multilingual light model." - ) + prompt = ("This is a test prompt that will exceed the max tokens limit" + " for multilingual light model.") with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -625,18 +537,14 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede def test_cohere_command_light_model(cohere_instance): # Test using the Command Light model for text generation cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text using Command Light model." - ) + response = cohere_instance("Generate text using Command Light model.") assert isinstance(response, str) def test_cohere_base_light_model(cohere_instance): # Test using the Base Light model for text generation cohere_instance.model = "base-light" - response = cohere_instance( - "Generate text using Base Light model." 
- ) + response = cohere_instance("Generate text using Base Light model.") assert isinstance(response, str) @@ -647,9 +555,7 @@ def test_cohere_generate_summarize_endpoint(cohere_instance): assert isinstance(response, str) -def test_cohere_representation_model_english_embedding( - cohere_instance, -): +def test_cohere_representation_model_english_embedding(cohere_instance,): # Test using the Representation model for English text embedding cohere_instance.model = "embed-english-v3.0" embedding = cohere_instance.embed("Generate English embeddings.") @@ -657,90 +563,69 @@ def test_cohere_representation_model_english_embedding( assert len(embedding) > 0 -def test_cohere_representation_model_english_classification( - cohere_instance, -): +def test_cohere_representation_model_english_classification(cohere_instance,): # Test using the Representation model for English text classification cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify( - "Classify English text." - ) + classification = cohere_instance.classify("Classify English text.") assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_english_language_detection( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for English language detection cohere_instance.model = "embed-english-v3.0" language = cohere_instance.detect_language( - "Detect the language of English text." - ) + "Detect the language of English text.") assert isinstance(language, str) def test_cohere_representation_model_english_max_tokens_limit_exceeded( - cohere_instance, -): + cohere_instance,): # Test handling max tokens limit exceeded error for English model cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for English model." - ) + prompt = ("This is a test prompt that will exceed the max tokens limit" + " for English model.") with pytest.raises(ValueError): cohere_instance.embed(prompt) -def test_cohere_representation_model_english_light_embedding( - cohere_instance, -): +def test_cohere_representation_model_english_light_embedding(cohere_instance,): # Test using the Representation model for English light text embedding cohere_instance.model = "embed-english-light-v3.0" - embedding = cohere_instance.embed( - "Generate English light embeddings." - ) + embedding = cohere_instance.embed("Generate English light embeddings.") assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_english_light_classification( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for English light text classification cohere_instance.model = "embed-english-light-v3.0" - classification = cohere_instance.classify( - "Classify English light text." - ) + classification = cohere_instance.classify("Classify English light text.") assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_english_light_language_detection( - cohere_instance, -): + cohere_instance,): # Test using the Representation model for English light language detection cohere_instance.model = "embed-english-light-v3.0" language = cohere_instance.detect_language( - "Detect the language of English light text." 
- ) + "Detect the language of English light text.") assert isinstance(language, str) def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( - cohere_instance, -): + cohere_instance,): # Test handling max tokens limit exceeded error for English light model cohere_instance.model = "embed-english-light-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for English light model." - ) + prompt = ("This is a test prompt that will exceed the max tokens limit" + " for English light model.") with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -748,9 +633,7 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( def test_cohere_command_model(cohere_instance): # Test using the Command model for text generation cohere_instance.model = "command" - response = cohere_instance( - "Generate text using the Command model." - ) + response = cohere_instance("Generate text using the Command model.") assert isinstance(response, str) @@ -764,9 +647,7 @@ def test_cohere_invalid_model(cohere_instance): cohere_instance("Generate text using an invalid model.") -def test_cohere_base_model_generation_with_max_tokens( - cohere_instance, -): +def test_cohere_base_model_generation_with_max_tokens(cohere_instance,): # Test generating text using the base model with a specified max_tokens limit cohere_instance.model = "base" cohere_instance.max_tokens = 20 diff --git a/tests/models/test_elevenlab.py b/tests/models/test_elevenlab.py index da41ca53..11359f9a 100644 --- a/tests/models/test_elevenlab.py +++ b/tests/models/test_elevenlab.py @@ -30,45 +30,30 @@ def test_run_text_to_speech(eleven_labs_tool): def test_play_speech(eleven_labs_tool): - with patch( - "builtins.open", mock_open(read_data="fake_audio_data") - ): + with patch("builtins.open", mock_open(read_data="fake_audio_data")): eleven_labs_tool.play(EXPECTED_SPEECH_FILE) def test_stream_speech(eleven_labs_tool): - with patch( - "tempfile.NamedTemporaryFile", mock_open() - ) as mock_file: + with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file: eleven_labs_tool.stream_speech(SAMPLE_TEXT) - mock_file.assert_called_with( - mode="bx", suffix=".wav", delete=False - ) + mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False) # Testing fixture and environment variables def test_api_key_validation(eleven_labs_tool): - with patch( - "langchain.utils.get_from_dict_or_env", return_value=API_KEY - ): + with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY): values = {"eleven_api_key": None} - validated_values = eleven_labs_tool.validate_environment( - values - ) + validated_values = eleven_labs_tool.validate_environment(values) assert "eleven_api_key" in validated_values # Mocking the external library def test_run_text_to_speech_with_mock(eleven_labs_tool): - with patch( - "tempfile.NamedTemporaryFile", mock_open() - ) as mock_file, patch( - "your_module._import_elevenlabs" - ) as mock_elevenlabs: + with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch( + "your_module._import_elevenlabs") as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value - mock_elevenlabs_instance.generate.return_value = ( - b"fake_audio_data" - ) + mock_elevenlabs_instance.generate.return_value = (b"fake_audio_data") eleven_labs_tool.run(SAMPLE_TEXT) assert mock_file.call_args[1]["suffix"] == ".wav" assert mock_file.call_args[1]["delete"] is False @@ -80,14 +65,11 @@ def 
test_run_text_to_speech_error_handling(eleven_labs_tool): with patch("your_module._import_elevenlabs") as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value mock_elevenlabs_instance.generate.side_effect = Exception( - "Test Exception" - ) + "Test Exception") with pytest.raises( - RuntimeError, - match=( - "Error while running ElevenLabsText2SpeechTool: Test" - " Exception" - ), + RuntimeError, + match=("Error while running ElevenLabsText2SpeechTool: Test" + " Exception"), ): eleven_labs_tool.run(SAMPLE_TEXT) @@ -97,9 +79,7 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): "model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL], ) -def test_run_text_to_speech_with_different_models( - eleven_labs_tool, model -): +def test_run_text_to_speech_with_different_models(eleven_labs_tool, model): eleven_labs_tool.model = model speech_file = eleven_labs_tool.run(SAMPLE_TEXT) assert isinstance(speech_file, str) diff --git a/tests/models/test_fire_function_caller.py b/tests/models/test_fire_function_caller.py index 082d954d..5e859272 100644 --- a/tests/models/test_fire_function_caller.py +++ b/tests/models/test_fire_function_caller.py @@ -39,6 +39,4 @@ def test_fire_function_caller_run(mocker): tokenizer.batch_decode.assert_called_once_with(generated_ids) # Assert the decoded output is printed - assert decoded_output in mocker.patch.object( - print, "call_args_list" - ) + assert decoded_output in mocker.patch.object(print, "call_args_list") diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py index e76e11bb..79ce79e5 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/test_fuyu.py @@ -38,9 +38,7 @@ def fuyu_instance(): # Test using the fixture. def test_fuyu_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) - assert isinstance( - fuyu_instance.image_processor, FuyuImageProcessor - ) + assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) # Test exception when providing an invalid image path. @@ -51,6 +49,7 @@ def test_invalid_image_path(fuyu_instance): # Using monkeypatch to replace the Image.open method to simulate a failure. 
def test_image_open_failure(fuyu_instance, monkeypatch): + def mock_open(*args, **kwargs): raise Exception("Mocked failure") @@ -79,13 +78,9 @@ def test_tokenizer_type(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance): - assert ( - fuyu_instance.processor.image_processor - == fuyu_instance.image_processor - ) - assert ( - fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer - ) + assert (fuyu_instance.processor.image_processor == + fuyu_instance.image_processor) + assert (fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer) def test_model_device_map(fuyu_instance): @@ -144,22 +139,14 @@ def test_get_img_invalid_path(fuyu_instance): # Test `run` method with valid inputs def test_run_valid_inputs(fuyu_instance): - with patch.object( - fuyu_instance, "get_img" - ) as mock_get_img, patch.object( - fuyu_instance, "processor" - ) as mock_processor, patch.object( - fuyu_instance, "model" - ) as mock_model: + with patch.object(fuyu_instance, "get_img") as mock_get_img, patch.object( + fuyu_instance, "processor") as mock_processor, patch.object( + fuyu_instance, "model") as mock_model: mock_get_img.return_value = "Test image" - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } + mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test text"] - result = fuyu_instance.run( - "Hello, world!", "valid/path/to/image.png" - ) + result = fuyu_instance.run("Hello, world!", "valid/path/to/image.png") assert result == ["Test text"] @@ -186,9 +173,7 @@ def test_run_invalid_image_path(fuyu_instance): with patch.object(fuyu_instance, "get_img") as mock_get_img: mock_get_img.side_effect = FileNotFoundError with pytest.raises(FileNotFoundError): - fuyu_instance.run( - "Hello, world!", "invalid/path/to/image.png" - ) + fuyu_instance.run("Hello, world!", "invalid/path/to/image.png") # Test `__init__` method with default parameters diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a61d1676..efe716ff 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -24,12 +24,8 @@ def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model): assert model.model is mock_genai_model -def test_gemini_init_custom_params( - mock_gemini_api_key, mock_genai_model -): - model = Gemini( - model_name="custom-model", gemini_api_key="custom-api-key" - ) +def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model): + model = Gemini(model_name="custom-model", gemini_api_key="custom-api-key") assert model.model_name == "custom-model" assert model.gemini_api_key == "custom-api-key" assert model.model is mock_genai_model @@ -54,16 +50,13 @@ def test_gemini_run_with_img( response = model.run(task=task, img=img) assert response == "Generated response" - mock_generate_content.assert_called_with( - content=[task, "Processed image"] - ) + mock_generate_content.assert_called_with(content=[task, "Processed image"]) mock_process_img.assert_called_with(img=img) @patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_without_img( - mock_generate_content, mock_gemini_api_key, mock_genai_model -): +def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key, + mock_genai_model): model = Gemini() task = "A cat" response_mock = Mock(text="Generated response") @@ -76,9 +69,8 @@ def test_gemini_run_without_img( 
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_exception( - mock_generate_content, mock_gemini_api_key, mock_genai_model -): +def test_gemini_run_exception(mock_generate_content, mock_gemini_api_key, + mock_genai_model): model = Gemini() task = "A cat" mock_generate_content.side_effect = Exception("Test exception") @@ -96,30 +88,23 @@ def test_gemini_process_img(mock_gemini_api_key, mock_genai_model): with patch("builtins.open", create=True) as open_mock: open_mock.return_value.__enter__.return_value.read.return_value = ( - img_data - ) + img_data) processed_img = model.process_img(img) - assert processed_img == [ - {"mime_type": "image/png", "data": img_data} - ] + assert processed_img == [{"mime_type": "image/png", "data": img_data}] open_mock.assert_called_with(img, "rb") # Test Gemini initialization with missing API key def test_gemini_init_missing_api_key(): - with pytest.raises( - ValueError, match="Please provide a Gemini API key" - ): + with pytest.raises(ValueError, match="Please provide a Gemini API key"): Gemini(gemini_api_key=None) # Test Gemini initialization with missing model name def test_gemini_init_missing_model_name(): - with pytest.raises( - ValueError, match="Please provide a model name" - ): + with pytest.raises(ValueError, match="Please provide a model name"): Gemini(model_name=None) @@ -141,26 +126,20 @@ def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model): # Test Gemini process_img method with missing image -def test_gemini_process_img_missing_image( - mock_gemini_api_key, mock_genai_model -): +def test_gemini_process_img_missing_image(mock_gemini_api_key, + mock_genai_model): model = Gemini() img = None - with pytest.raises( - ValueError, match="Please provide an image to process" - ): + with pytest.raises(ValueError, match="Please provide an image to process"): model.process_img(img=img) # Test Gemini process_img method with missing image type -def test_gemini_process_img_missing_image_type( - mock_gemini_api_key, mock_genai_model -): +def test_gemini_process_img_missing_image_type(mock_gemini_api_key, + mock_genai_model): model = Gemini() img = "cat.png" - with pytest.raises( - ValueError, match="Please provide the image type" - ): + with pytest.raises(ValueError, match="Please provide the image type"): model.process_img(img=img, type=None) @@ -168,9 +147,7 @@ def test_gemini_process_img_missing_image_type( def test_gemini_process_img_missing_api_key(mock_genai_model): model = Gemini(gemini_api_key=None) img = "cat.png" - with pytest.raises( - ValueError, match="Please provide a Gemini API key" - ): + with pytest.raises(ValueError, match="Please provide a Gemini API key"): model.process_img(img=img, type="image/png") @@ -193,9 +170,7 @@ def test_gemini_run_mock_img_processing( response = model.run(task=task, img=img) assert response == "Generated response" - mock_generate_content.assert_called_with( - content=[task, "Processed image"] - ) + mock_generate_content.assert_called_with(content=[task, "Processed image"]) mock_process_img.assert_called_with(img=img) diff --git a/tests/models/test_gigabind.py b/tests/models/test_gigabind.py index 3aae0739..493dbca3 100644 --- a/tests/models/test_gigabind.py +++ b/tests/models/test_gigabind.py @@ -11,16 +11,13 @@ except ImportError: @pytest.fixture def api(): - return Gigabind( - host="localhost", port=8000, endpoint="embeddings" - ) + return Gigabind(host="localhost", port=8000, endpoint="embeddings") @pytest.fixture def mock(requests_mock): - 
requests_mock.post( - "http://localhost:8000/embeddings", json={"result": "success"} - ) + requests_mock.post("http://localhost:8000/embeddings", + json={"result": "success"}) return requests_mock @@ -40,9 +37,9 @@ def test_run_with_audio(api, mock): def test_run_with_all(api, mock): - response = api.run( - text="Hello, world!", vision="image.jpg", audio="audio.mp3" - ) + response = api.run(text="Hello, world!", + vision="image.jpg", + audio="audio.mp3") assert response == {"result": "success"} @@ -65,9 +62,20 @@ def test_retry_on_failure(api, requests_mock): requests_mock.post( "http://localhost:8000/embeddings", [ - {"status_code": 500, "json": {}}, - {"status_code": 500, "json": {}}, - {"status_code": 200, "json": {"result": "success"}}, + { + "status_code": 500, + "json": {} + }, + { + "status_code": 500, + "json": {} + }, + { + "status_code": 200, + "json": { + "result": "success" + } + }, ], ) response = api.run(text="Hello, world!") @@ -78,9 +86,18 @@ def test_retry_exhausted(api, requests_mock): requests_mock.post( "http://localhost:8000/embeddings", [ - {"status_code": 500, "json": {}}, - {"status_code": 500, "json": {}}, - {"status_code": 500, "json": {}}, + { + "status_code": 500, + "json": {} + }, + { + "status_code": 500, + "json": {} + }, + { + "status_code": 500, + "json": {} + }, ], ) response = api.run(text="Hello, world!") @@ -93,9 +110,7 @@ def test_proxy_url(api): def test_invalid_response(api, requests_mock): - requests_mock.post( - "http://localhost:8000/embeddings", text="not json" - ) + requests_mock.post("http://localhost:8000/embeddings", text="not json") response = api.run(text="Hello, world!") assert response is None @@ -110,9 +125,7 @@ def test_connection_error(api, requests_mock): def test_http_error(api, requests_mock): - requests_mock.post( - "http://localhost:8000/embeddings", status_code=500 - ) + requests_mock.post("http://localhost:8000/embeddings", status_code=500) response = api.run(text="Hello, world!") assert response is None @@ -148,9 +161,7 @@ def test_run_with_large_all(api, mock): large_text = "Hello, world! 
" * 10000 # 10,000 repetitions large_vision = "image.jpg" * 10000 # 10,000 repetitions large_audio = "audio.mp3" * 10000 # 10,000 repetitions - response = api.run( - text=large_text, vision=large_vision, audio=large_audio - ) + response = api.run(text=large_text, vision=large_vision, audio=large_audio) assert response == {"result": "success"} diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index ac797280..434ee502 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -26,9 +26,9 @@ def test_init(vision_api): def test_encode_image(vision_api): with patch( - "builtins.open", - mock_open(read_data=b"test_image_data"), - create=True, + "builtins.open", + mock_open(read_data=b"test_image_data"), + create=True, ): encoded_image = vision_api.encode_image(img) assert encoded_image == "dGVzdF9pbWFnZV9kYXRh" @@ -37,8 +37,8 @@ def test_encode_image(vision_api): def test_run_success(vision_api): expected_response = {"This is the model's response."} with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api.run("What is this?", img) mock_post.assert_called_once() @@ -46,9 +46,7 @@ def test_run_success(vision_api): def test_run_request_error(vision_api): - with patch( - "requests.post", side_effect=RequestException("Request Error") - ): + with patch("requests.post", side_effect=RequestException("Request Error")): with pytest.raises(RequestException): vision_api.run("What is this?", img) @@ -56,20 +54,18 @@ def test_run_request_error(vision_api): def test_run_response_error(vision_api): expected_response = {"error": "Model Error"} with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ): with pytest.raises(RuntimeError): vision_api.run("What is this?", img) def test_call(vision_api): - expected_response = { - "choices": [{"text": "This is the model's response."}] - } + expected_response = {"choices": [{"text": "This is the model's response."}]} with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api("What is this?", img) mock_post.assert_called_once() @@ -95,9 +91,7 @@ def test_initialization_with_custom_key(): def test_run_with_exception(gpt_api): task = "What is in the image?" img_url = img - with patch( - "requests.post", side_effect=Exception("Test Exception") - ): + with patch("requests.post", side_effect=Exception("Test Exception")): with pytest.raises(Exception): gpt_api.run(task, img_url) @@ -105,14 +99,10 @@ def test_run_with_exception(gpt_api): def test_call_method_successful_response(gpt_api): task = "What is in the image?" img_url = img - response_json = { - "choices": [{"text": "Answer from GPT-4 Vision"}] - } + response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]} mock_response = Mock() mock_response.json.return_value = response_json - with patch( - "requests.post", return_value=mock_response - ) as mock_post: + with patch("requests.post", return_value=mock_response) as mock_post: result = gpt_api(task, img_url) mock_post.assert_called_once() assert result == response_json @@ -121,9 +111,7 @@ def test_call_method_successful_response(gpt_api): def test_call_method_with_exception(gpt_api): task = "What is in the image?" 
img_url = img - with patch( - "requests.post", side_effect=Exception("Test Exception") - ): + with patch("requests.post", side_effect=Exception("Test Exception")): with pytest.raises(Exception): gpt_api(task, img_url) @@ -131,16 +119,17 @@ def test_call_method_with_exception(gpt_api): @pytest.mark.asyncio async def test_arun_success(vision_api): expected_response = { - "choices": [ - {"message": {"content": "This is the model's response."}} - ] + "choices": [{ + "message": { + "content": "This is the model's response." + } + }] } with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(return_value=expected_response) - ), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock(json=AsyncMock( + return_value=expected_response)), ) as mock_post: result = await vision_api.arun("What is this?", img) mock_post.assert_called_once() @@ -150,9 +139,9 @@ async def test_arun_success(vision_api): @pytest.mark.asyncio async def test_arun_request_error(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=Exception("Request Error"), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=Exception("Request Error"), ): with pytest.raises(Exception): await vision_api.arun("What is this?", img) @@ -160,13 +149,15 @@ async def test_arun_request_error(vision_api): def test_run_many_success(vision_api): expected_response = { - "choices": [ - {"message": {"content": "This is the model's response."}} - ] + "choices": [{ + "message": { + "content": "This is the model's response." + } + }] } with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: tasks = ["What is this?", "What is that?"] imgs = [img, img] @@ -179,9 +170,7 @@ def test_run_many_success(vision_api): def test_run_many_request_error(vision_api): - with patch( - "requests.post", side_effect=RequestException("Request Error") - ): + with patch("requests.post", side_effect=RequestException("Request Error")): tasks = ["What is this?", "What is that?"] imgs = [img, img] with pytest.raises(RequestException): @@ -191,11 +180,9 @@ def test_run_many_request_error(vision_api): @pytest.mark.asyncio async def test_arun_json_decode_error(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(side_effect=ValueError) - ), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)), ): with pytest.raises(ValueError): await vision_api.arun("What is this?", img) @@ -205,11 +192,9 @@ async def test_arun_json_decode_error(vision_api): async def test_arun_api_error(vision_api): error_response = {"error": {"message": "API Error"}} with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(return_value=error_response) - ), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock(json=AsyncMock(return_value=error_response)), ): with pytest.raises(Exception, match="API Error"): await vision_api.arun("What is this?", img) @@ -219,11 +204,10 @@ async def test_arun_api_error(vision_api): async def test_arun_unexpected_response(vision_api): unexpected_response = {"unexpected": "response"} with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - 
json=AsyncMock(return_value=unexpected_response) - ), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock(json=AsyncMock( + return_value=unexpected_response)), ): with pytest.raises(Exception, match="Unexpected response"): await vision_api.arun("What is this?", img) @@ -232,9 +216,9 @@ async def test_arun_unexpected_response(vision_api): @pytest.mark.asyncio async def test_arun_retries(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=ClientResponseError(None, None), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=ClientResponseError(None, None), ) as mock_post: with pytest.raises(ClientResponseError): await vision_api.arun("What is this?", img) @@ -244,9 +228,9 @@ async def test_arun_retries(vision_api): @pytest.mark.asyncio async def test_arun_timeout(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=asyncio.TimeoutError, + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=asyncio.TimeoutError, ): with pytest.raises(asyncio.TimeoutError): await vision_api.arun("What is this?", img) diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py index cbbba940..31fc41d4 100644 --- a/tests/models/test_hf.py +++ b/tests/models/test_hf.py @@ -133,9 +133,7 @@ def test_llm_set_repitition_penalty(llm_instance): def test_llm_set_no_repeat_ngram_size(llm_instance): new_no_repeat_ngram_size = 6 llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size) - assert ( - llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size - ) + assert (llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size) # Test for setting temperature @@ -185,9 +183,7 @@ def test_llm_set_model_id(llm_instance): # Test for setting model -@patch( - "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" -) +@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained") def test_llm_set_model(mock_model, llm_instance): mock_model.return_value = "mocked model" llm_instance.set_model(mock_model) diff --git a/tests/models/test_hf_pipeline.py b/tests/models/test_hf_pipeline.py index 8580dd56..ec306797 100644 --- a/tests/models/test_hf_pipeline.py +++ b/tests/models/test_hf_pipeline.py @@ -14,19 +14,14 @@ def mock_pipeline(): @pytest.fixture def pipeline(mock_pipeline): - return HuggingfacePipeline( - "text-generation", "meta-llama/Llama-2-13b-chat-hf" - ) + return HuggingfacePipeline("text-generation", + "meta-llama/Llama-2-13b-chat-hf") def test_init(pipeline, mock_pipeline): assert pipeline.task_type == "text-generation" assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf" - assert ( - pipeline.use_fp8 is True - if torch.cuda.is_available() - else False - ) + assert (pipeline.use_fp8 is True if torch.cuda.is_available() else False) mock_pipeline.assert_called_once_with( "text-generation", "meta-llama/Llama-2-13b-chat-hf", @@ -51,6 +46,5 @@ def test_run_with_different_task(pipeline, mock_pipeline): mock_pipeline.return_value = "Generated text" result = pipeline.run("text-classification", "Hello, world!") assert result == "Generated text" - mock_pipeline.assert_called_once_with( - "text-classification", "Hello, world!" 
- ) + mock_pipeline.assert_called_once_with("text-classification", + "Hello, world!") diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 7e19a056..5848e42f 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -18,10 +18,7 @@ def llm_instance(): # Test for instantiation and attributes def test_llm_initialization(llm_instance): - assert ( - llm_instance.model_id - == "NousResearch/Nous-Hermes-2-Vision-Alpha" - ) + assert (llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha") assert llm_instance.max_length == 500 # ... add more assertions for all default attributes @@ -88,15 +85,12 @@ def test_llm_memory_consumption(llm_instance): ) def test_llm_initialization_params(model_id, max_length): if max_length: - instance = HuggingfaceLLM( - model_id=model_id, max_length=max_length - ) + instance = HuggingfaceLLM(model_id=model_id, max_length=max_length) assert instance.max_length == max_length else: instance = HuggingfaceLLM(model_id=model_id) - assert ( - instance.max_length == 500 - ) # Assuming 500 is the default max_length + assert (instance.max_length == 500 + ) # Assuming 500 is the default max_length # Test for setting an invalid device @@ -144,9 +138,7 @@ def test_llm_run_output_length(mock_run, llm_instance): # Test the tokenizer handling special tokens correctly @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") -def test_llm_tokenizer_special_tokens( - mock_decode, mock_encode, llm_instance -): +def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance): mock_encode.return_value = "encoded input with special tokens" mock_decode.return_value = "decoded output with special tokens" result = llm_instance.run("test task with special tokens") @@ -172,9 +164,8 @@ def test_llm_response_time(mock_run, llm_instance): start_time = time.time() llm_instance.run("test task for response time") end_time = time.time() - assert ( - end_time - start_time < 1 - ) # Assuming the response should be faster than 1 second + assert (end_time - start_time + < 1) # Assuming the response should be faster than 1 second # Test the logging of a warning for long inputs @@ -197,13 +188,10 @@ def test_llm_run_model_exception(mock_generate, llm_instance): # Test the behavior when GPU is forced but not available @patch("torch.cuda.is_available", return_value=False) -def test_llm_force_gpu_when_unavailable( - mock_is_available, llm_instance -): +def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): with pytest.raises(EnvironmentError): llm_instance.set_device( - "cuda" - ) # Attempt to set CUDA when it's not available + "cuda") # Attempt to set CUDA when it's not available # Test for proper cleanup after model use (releasing resources) @@ -221,9 +209,8 @@ def test_llm_multilingual_input(mock_run, llm_instance): mock_run.return_value = "mocked multilingual output" multilingual_input = "Bonjour, ceci est un test multilingue." 
result = llm_instance.run(multilingual_input) - assert isinstance( - result, str - ) # Simple check to ensure output is string type + assert isinstance(result, + str) # Simple check to ensure output is string type # Test caching mechanism to prevent re-running the same inputs diff --git a/tests/models/test_idefics.py b/tests/models/test_idefics.py index 3bfee679..cde670f5 100644 --- a/tests/models/test_idefics.py +++ b/tests/models/test_idefics.py @@ -13,8 +13,8 @@ from swarms.models.idefics import ( @pytest.fixture def idefics_instance(): with patch( - "torch.cuda.is_available", return_value=False - ): # Assuming tests are run on CPU for simplicity + "torch.cuda.is_available", + return_value=False): # Assuming tests are run on CPU for simplicity instance = Idefics() return instance @@ -36,8 +36,8 @@ def test_init_default(idefics_instance): ) def test_init_device(device, expected): with patch( - "torch.cuda.is_available", - return_value=True if expected == "cuda" else False, + "torch.cuda.is_available", + return_value=True if expected == "cuda" else False, ): instance = Idefics(device=device) assert instance.device == expected @@ -46,14 +46,10 @@ def test_init_device(device, expected): # Test `run` method def test_run(idefics_instance): prompts = [["User: Test"]] - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } + with patch.object(idefics_instance, + "processor") as mock_processor, patch.object( + idefics_instance, "model") as mock_model: + mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -65,14 +61,10 @@ def test_run(idefics_instance): # Test `__call__` method (using the same logic as run for simplicity) def test_call(idefics_instance): prompts = [["User: Test"]] - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } + with patch.object(idefics_instance, + "processor") as mock_processor, patch.object( + idefics_instance, "model") as mock_model: + mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -85,9 +77,7 @@ def test_call(idefics_instance): def test_chat(idefics_instance): user_input = "User: Hello" response = "Model: Hi there!" 
- with patch.object( - idefics_instance, "run", return_value=[response] - ): + with patch.object(idefics_instance, "run", return_value=[response]): result = idefics_instance.chat(user_input) assert result == response @@ -97,16 +87,13 @@ def test_chat(idefics_instance): # Test `set_checkpoint` method def test_set_checkpoint(idefics_instance): new_checkpoint = "new_checkpoint" - with patch.object( - IdeficsForVisionText2Text, "from_pretrained" - ) as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained" - ): + with patch.object(IdeficsForVisionText2Text, + "from_pretrained") as mock_from_pretrained, patch.object( + AutoProcessor, "from_pretrained"): idefics_instance.set_checkpoint(new_checkpoint) - mock_from_pretrained.assert_called_with( - new_checkpoint, torch_dtype=torch.bfloat16 - ) + mock_from_pretrained.assert_called_with(new_checkpoint, + torch_dtype=torch.bfloat16) # Test `set_device` method @@ -135,7 +122,7 @@ def test_clear_chat_history(idefics_instance): # Exception Tests def test_run_with_empty_prompts(idefics_instance): with pytest.raises( - Exception + Exception ): # Replace Exception with the actual exception that may arise for an empty prompt. idefics_instance.run([]) @@ -143,14 +130,10 @@ def test_run_with_empty_prompts(idefics_instance): # Test `run` method with batched_mode set to False def test_run_batched_mode_false(idefics_instance): task = "User: Test" - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } + with patch.object(idefics_instance, + "processor") as mock_processor, patch.object( + idefics_instance, "model") as mock_model: + mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -163,9 +146,7 @@ def test_run_batched_mode_false(idefics_instance): # Test `run` method with an exception def test_run_with_exception(idefics_instance): task = "User: Test" - with patch.object( - idefics_instance, "processor" - ) as mock_processor: + with patch.object(idefics_instance, "processor") as mock_processor: mock_processor.side_effect = Exception("Test exception") with pytest.raises(Exception): idefics_instance.run(task) @@ -174,24 +155,21 @@ def test_run_with_exception(idefics_instance): # Test `set_model_name` method def test_set_model_name(idefics_instance): new_model_name = "new_model_name" - with patch.object( - IdeficsForVisionText2Text, "from_pretrained" - ) as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained" - ): + with patch.object(IdeficsForVisionText2Text, + "from_pretrained") as mock_from_pretrained, patch.object( + AutoProcessor, "from_pretrained"): idefics_instance.set_model_name(new_model_name) assert idefics_instance.model_name == new_model_name - mock_from_pretrained.assert_called_with( - new_model_name, torch_dtype=torch.bfloat16 - ) + mock_from_pretrained.assert_called_with(new_model_name, + torch_dtype=torch.bfloat16) # Test `__init__` method with device set to None def test_init_device_none(): with patch( - "torch.cuda.is_available", - return_value=False, + "torch.cuda.is_available", + return_value=False, ): instance = Idefics(device=None) assert instance.device == "cpu" @@ -200,8 +178,8 @@ def test_init_device_none(): # Test `__init__` method with device set to "cuda" def test_init_device_cuda(): with patch( - 
"torch.cuda.is_available", - return_value=True, + "torch.cuda.is_available", + return_value=True, ): instance = Idefics(device="cuda") assert instance.device == "cuda" diff --git a/tests/models/test_kosmos.py b/tests/models/test_kosmos.py index 1219f895..ad3718ca 100644 --- a/tests/models/test_kosmos.py +++ b/tests/models/test_kosmos.py @@ -16,9 +16,7 @@ def mock_image_request(): img_data = open(TEST_IMAGE_URL, "rb").read() mock_resp = Mock() mock_resp.raw = img_data - with patch.object( - requests, "get", return_value=mock_resp - ) as _fixture: + with patch.object(requests, "get", return_value=mock_resp) as _fixture: yield _fixture @@ -47,18 +45,16 @@ def test_get_image(mock_image_request): # Test multimodal grounding def test_multimodal_grounding(mock_image_request): kosmos = Kosmos() - kosmos.multimodal_grounding( - "Find the red apple in the image.", TEST_IMAGE_URL - ) + kosmos.multimodal_grounding("Find the red apple in the image.", + TEST_IMAGE_URL) # TODO: Validate the result if possible # Test referring expression comprehension def test_referring_expression_comprehension(mock_image_request): kosmos = Kosmos() - kosmos.referring_expression_comprehension( - "Show me the green bottle.", TEST_IMAGE_URL - ) + kosmos.referring_expression_comprehension("Show me the green bottle.", + TEST_IMAGE_URL) # TODO: Validate the result if possible @@ -93,6 +89,7 @@ IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=fo # Mock response for requests.get() class MockResponse: + @staticmethod def json(): return {} @@ -111,30 +108,23 @@ def kosmos(): # Mocking the requests.get() method @pytest.fixture def mock_request_get(monkeypatch): - monkeypatch.setattr( - requests, "get", lambda url, **kwargs: MockResponse() - ) + monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse()) @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding(kosmos): - kosmos.multimodal_grounding( - "Find the red apple in the image.", IMG_URL1 - ) + kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension(kosmos): - kosmos.referring_expression_comprehension( - "Show me the green bottle.", IMG_URL2 - ) + kosmos.referring_expression_comprehension("Show me the green bottle.", + IMG_URL2) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_generation(kosmos): - kosmos.referring_expression_generation( - "It is on the table.", IMG_URL3 - ) + kosmos.referring_expression_generation("It is on the table.", IMG_URL3) @pytest.mark.usefixtures("mock_request_get") @@ -154,16 +144,13 @@ def test_grounded_image_captioning_detailed(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding_2(kosmos): - kosmos.multimodal_grounding( - "Find the yellow fruit in the image.", IMG_URL2 - ) + kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension_2(kosmos): - kosmos.referring_expression_comprehension( - "Where is the water bottle?", IMG_URL3 - ) + kosmos.referring_expression_comprehension("Where is the water bottle?", + IMG_URL3) @pytest.mark.usefixtures("mock_request_get") diff --git a/tests/models/test_llama_function_caller.py b/tests/models/test_llama_function_caller.py index 1e9df654..f7afb90c 100644 --- a/tests/models/test_llama_function_caller.py +++ b/tests/models/test_llama_function_caller.py @@ -18,6 +18,7 @@ def 
test_llama_model_loading(llama_caller): # Test adding and calling custom functions def test_llama_custom_function(llama_caller): + def sample_function(arg1, arg2): return f"Sample function called with args: {arg1}, {arg2}" @@ -39,13 +40,11 @@ def test_llama_custom_function(llama_caller): ], ) - result = llama_caller.call_function( - "sample_function", arg1="arg1_value", arg2="arg2_value" - ) + result = llama_caller.call_function("sample_function", + arg1="arg1_value", + arg2="arg2_value") assert ( - result - == "Sample function called with args: arg1_value, arg2_value" - ) + result == "Sample function called with args: arg1_value, arg2_value") # Test streaming user prompts @@ -64,6 +63,7 @@ def test_llama_custom_function_not_found(llama_caller): # Test invalid arguments for custom function def test_llama_custom_function_invalid_arguments(llama_caller): + def sample_function(arg1, arg2): return f"Sample function called with args: {arg1}, {arg2}" @@ -86,9 +86,7 @@ def test_llama_custom_function_invalid_arguments(llama_caller): ) with pytest.raises(TypeError): - llama_caller.call_function( - "sample_function", arg1="arg1_value" - ) + llama_caller.call_function("sample_function", arg1="arg1_value") # Test streaming with custom runtime diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py index a68a9026..ce26a777 100644 --- a/tests/models/test_mixtral.py +++ b/tests/models/test_mixtral.py @@ -21,22 +21,18 @@ def test_mixtral_run(mock_model, mock_tokenizer): mixtral = Mixtral() mock_tokenizer_instance = MagicMock() mock_model_instance = MagicMock() - mock_tokenizer.from_pretrained.return_value = ( - mock_tokenizer_instance - ) + mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) mock_model.from_pretrained.return_value = mock_model_instance mock_tokenizer_instance.return_tensors = "pt" mock_model_instance.generate.return_value = [101, 102, 103] mock_tokenizer_instance.decode.return_value = "Generated text" result = mixtral.run("Test task") assert result == "Generated text" - mock_tokenizer_instance.assert_called_once_with( - "Test task", return_tensors="pt" - ) + mock_tokenizer_instance.assert_called_once_with("Test task", + return_tensors="pt") mock_model_instance.generate.assert_called_once() mock_tokenizer_instance.decode.assert_called_once_with( - [101, 102, 103], skip_special_tokens=True - ) + [101, 102, 103], skip_special_tokens=True) @patch("swarms.models.mixtral.AutoTokenizer") @@ -45,9 +41,7 @@ def test_mixtral_run_error(mock_model, mock_tokenizer): mixtral = Mixtral() mock_tokenizer_instance = MagicMock() mock_model_instance = MagicMock() - mock_tokenizer.from_pretrained.return_value = ( - mock_tokenizer_instance - ) + mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) mock_model.from_pretrained.return_value = mock_model_instance mock_tokenizer_instance.return_tensors = "pt" mock_model_instance.generate.side_effect = Exception("Test error") diff --git a/tests/models/test_mpt7b.py b/tests/models/test_mpt7b.py index 92b6c254..e0b49624 100644 --- a/tests/models/test_mpt7b.py +++ b/tests/models/test_mpt7b.py @@ -25,14 +25,10 @@ def test_mpt7b_run(): "EleutherAI/gpt-neox-20b", max_tokens=150, ) - output = mpt.run( - "generate", "Once upon a time in a land far, far away..." - ) + output = mpt.run("generate", "Once upon a time in a land far, far away...") assert isinstance(output, str) - assert output.startswith( - "Once upon a time in a land far, far away..." 
- ) + assert output.startswith("Once upon a time in a land far, far away...") def test_mpt7b_run_invalid_task(): @@ -55,14 +51,10 @@ def test_mpt7b_generate(): "EleutherAI/gpt-neox-20b", max_tokens=150, ) - output = mpt.generate( - "Once upon a time in a land far, far away..." - ) + output = mpt.generate("Once upon a time in a land far, far away...") assert isinstance(output, str) - assert output.startswith( - "Once upon a time in a land far, far away..." - ) + assert output.startswith("Once upon a time in a land far, far away...") def test_mpt7b_batch_generate(): diff --git a/tests/models/test_nougat.py b/tests/models/test_nougat.py index 858845a6..2790fd77 100644 --- a/tests/models/test_nougat.py +++ b/tests/models/test_nougat.py @@ -43,9 +43,7 @@ def test_model_initialization(setup_nougat): "cuda_available, expected_device", [(True, "cuda"), (False, "cpu")], ) -def test_device_initialization( - cuda_available, expected_device, monkeypatch -): +def test_device_initialization(cuda_available, expected_device, monkeypatch): monkeypatch.setattr( torch, "cuda", @@ -74,9 +72,7 @@ def test_get_image_invalid_path(setup_nougat): (10, 50), ], ) -def test_model_call_with_diff_params( - setup_nougat, min_len, max_tokens -): +def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens): setup_nougat.min_length = min_len setup_nougat.max_new_tokens = max_tokens @@ -107,11 +103,11 @@ def test_model_call_mocked_output(setup_nougat): def mock_processor_and_model(): """Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior.""" with patch( - "transformers.NougatProcessor.from_pretrained", - return_value=Mock(), + "transformers.NougatProcessor.from_pretrained", + return_value=Mock(), ), patch( - "transformers.VisionEncoderDecoderModel.from_pretrained", - return_value=Mock(), + "transformers.VisionEncoderDecoderModel.from_pretrained", + return_value=Mock(), ): yield @@ -122,8 +118,7 @@ def test_nougat_with_sample_image_1(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @@ -140,8 +135,7 @@ def test_nougat_min_length_param(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @@ -152,8 +146,7 @@ def test_nougat_max_new_tokens_param(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @@ -164,16 +157,13 @@ def test_nougat_different_model_path(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @pytest.mark.usefixtures("mock_processor_and_model") def test_nougat_bad_image_path(setup_nougat): - with pytest.raises( - Exception - ): # Adjust the exception type accordingly. + with pytest.raises(Exception): # Adjust the exception type accordingly. 
setup_nougat("bad_image_path.png") @@ -183,8 +173,7 @@ def test_nougat_image_large_size(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @@ -194,8 +183,7 @@ def test_nougat_image_small_size(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - ) - ) + )) assert isinstance(result, str) @@ -205,8 +193,7 @@ def test_nougat_image_varied_content(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - ) - ) + )) assert isinstance(result, str) @@ -216,6 +203,5 @@ def test_nougat_image_with_metadata(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - ) - ) + )) assert isinstance(result, str) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index a920256c..1c6442b5 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -4,18 +4,16 @@ from swarms.models.qwen import QwenVLMultiModal def test_post_init(): - with patch( - "swarms.models.qwen.AutoTokenizer.from_pretrained" - ) as mock_tokenizer, patch( - "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" - ) as mock_model: + with patch("swarms.models.qwen.AutoTokenizer.from_pretrained" + ) as mock_tokenizer, patch( + "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" + ) as mock_model: mock_tokenizer.return_value = Mock() mock_model.return_value = Mock() model = QwenVLMultiModal() - mock_tokenizer.assert_called_once_with( - model.model_name, trust_remote_code=True - ) + mock_tokenizer.assert_called_once_with(model.model_name, + trust_remote_code=True) mock_model.assert_called_once_with( model.model_name, device_map=model.device, @@ -25,37 +23,31 @@ def test_post_init(): def test_run(): with patch( - "swarms.models.qwen.AutoTokenizer.from_list_format" + "swarms.models.qwen.AutoTokenizer.from_list_format" ) as mock_format, patch( - "swarms.models.qwen.AutoTokenizer.__call__" - ) as mock_call, patch( - "swarms.models.qwen.AutoModelForCausalLM.generate" - ) as mock_generate, patch( - "swarms.models.qwen.AutoTokenizer.decode" - ) as mock_decode: + "swarms.models.qwen.AutoTokenizer.__call__") as mock_call, patch( + "swarms.models.qwen.AutoModelForCausalLM.generate" + ) as mock_generate, patch( + "swarms.models.qwen.AutoTokenizer.decode") as mock_decode: mock_format.return_value = Mock() mock_call.return_value = Mock() mock_generate.return_value = Mock() mock_decode.return_value = "response" model = QwenVLMultiModal() - response = model.run( - "Hello, how are you?", "https://example.com/image.jpg" - ) + response = model.run("Hello, how are you?", + "https://example.com/image.jpg") assert response == "response" def test_chat(): - with patch( - "swarms.models.qwen.AutoModelForCausalLM.chat" - ) as mock_chat: + with patch("swarms.models.qwen.AutoModelForCausalLM.chat") as mock_chat: mock_chat.return_value = ("response", ["history"]) model = QwenVLMultiModal() - 
response, history = model.chat( - "Hello, how are you?", "https://example.com/image.jpg" - ) + response, history = model.chat("Hello, how are you?", + "https://example.com/image.jpg") assert response == "response" assert history == ["history"] diff --git a/tests/models/test_speech_t5.py b/tests/models/test_speech_t5.py index d32c21db..0fc61e6b 100644 --- a/tests/models/test_speech_t5.py +++ b/tests/models/test_speech_t5.py @@ -16,16 +16,11 @@ def speecht5_model(): def test_speecht5_init(speecht5_model): - assert isinstance( - speecht5_model.processor, SpeechT5.processor.__class__ - ) + assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) assert isinstance(speecht5_model.model, SpeechT5.model.__class__) - assert isinstance( - speecht5_model.vocoder, SpeechT5.vocoder.__class__ - ) - assert isinstance( - speecht5_model.embeddings_dataset, torch.utils.data.Dataset - ) + assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) + assert isinstance(speecht5_model.embeddings_dataset, + torch.utils.data.Dataset) def test_speecht5_call(speecht5_model): @@ -49,10 +44,7 @@ def test_speecht5_set_model(speecht5_model): speecht5_model.set_model(new_model_name) assert speecht5_model.model_name == new_model_name assert speecht5_model.processor.model_name == new_model_name - assert ( - speecht5_model.model.config.model_name_or_path - == new_model_name - ) + assert (speecht5_model.model.config.model_name_or_path == new_model_name) speecht5_model.set_model(old_model_name) # Restore original model @@ -62,12 +54,8 @@ def test_speecht5_set_vocoder(speecht5_model): speecht5_model.set_vocoder(new_vocoder_name) assert speecht5_model.vocoder_name == new_vocoder_name assert ( - speecht5_model.vocoder.config.model_name_or_path - == new_vocoder_name - ) - speecht5_model.set_vocoder( - old_vocoder_name - ) # Restore original vocoder + speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name) + speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder def test_speecht5_set_embeddings_dataset(speecht5_model): @@ -75,12 +63,10 @@ def test_speecht5_set_embeddings_dataset(speecht5_model): new_dataset_name = "Matthijs/cmu-arctic-xvectors-test" speecht5_model.set_embeddings_dataset(new_dataset_name) assert speecht5_model.dataset_name == new_dataset_name - assert isinstance( - speecht5_model.embeddings_dataset, torch.utils.data.Dataset - ) + assert isinstance(speecht5_model.embeddings_dataset, + torch.utils.data.Dataset) speecht5_model.set_embeddings_dataset( - old_dataset_name - ) # Restore original dataset + old_dataset_name) # Restore original dataset def test_speecht5_get_sampling_rate(speecht5_model): @@ -112,9 +98,7 @@ def test_speecht5_change_dataset_split(speecht5_model): def test_speecht5_load_custom_embedding(speecht5_model): xvector = [0.1, 0.2, 0.3, 0.4, 0.5] embedding = speecht5_model.load_custom_embedding(xvector) - assert torch.all( - torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)) - ) + assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))) def test_speecht5_with_different_speakers(speecht5_model): @@ -125,9 +109,7 @@ def test_speecht5_with_different_speakers(speecht5_model): assert isinstance(speech, torch.Tensor) -def test_speecht5_save_speech_with_different_extensions( - speecht5_model, -): +def test_speecht5_save_speech_with_different_extensions(speecht5_model,): text = "Hello, how are you?" 
speech = speecht5_model(text) extensions = [".wav", ".flac"] @@ -162,6 +144,4 @@ def test_speecht5_change_vocoder_model(speecht5_model): speecht5_model.set_vocoder(new_vocoder_name) speech = speecht5_model(text) assert isinstance(speech, torch.Tensor) - speecht5_model.set_vocoder( - old_vocoder_name - ) # Restore original vocoder + speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder diff --git a/tests/models/test_ssd_1b.py b/tests/models/test_ssd_1b.py index f658f853..b9b5bb25 100644 --- a/tests/models/test_ssd_1b.py +++ b/tests/models/test_ssd_1b.py @@ -21,36 +21,30 @@ def test_ssd1b_call(ssd1b_model): image_url = ssd1b_model(task, neg_prompt) assert isinstance(image_url, str) assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" + "https://") # Assuming it starts with "https://" # Add more tests for various aspects of the class and methods # Example of a parameterized test for different tasks -@pytest.mark.parametrize( - "task", ["A painting of a cat", "A painting of a tree"] -) +@pytest.mark.parametrize("task", + ["A painting of a cat", "A painting of a tree"]) def test_ssd1b_parameterized_task(ssd1b_model, task): image_url = ssd1b_model(task) assert isinstance(image_url, str) assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" + "https://") # Assuming it starts with "https://" # Example of a test using mocks to isolate units of code def test_ssd1b_with_mock(ssd1b_model, mocker): - mocker.patch( - "your_module.StableDiffusionXLPipeline" - ) # Mock the pipeline + mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline task = "A painting of a cat" image_url = ssd1b_model(task) assert isinstance(image_url, str) assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" + "https://") # Assuming it starts with "https://" def test_ssd1b_call_with_cache(ssd1b_model): @@ -68,9 +62,8 @@ def test_ssd1b_invalid_task(ssd1b_model): def test_ssd1b_failed_api_call(ssd1b_model, mocker): - mocker.patch( - "your_module.StableDiffusionXLPipeline" - ) # Mock the pipeline to raise an exception + mocker.patch("your_module.StableDiffusionXLPipeline" + ) # Mock the pipeline to raise an exception task = "A painting of a cat" with pytest.raises(Exception): ssd1b_model(task) diff --git a/tests/models/test_timm.py b/tests/models/test_timm.py index 4af689e5..f4dc8dc2 100644 --- a/tests/models/test_timm.py +++ b/tests/models/test_timm.py @@ -19,18 +19,16 @@ def test_timm_model_init(): def test_timm_model_call(): - with patch( - "swarms.models.timm.create_model" - ) as mock_create_model: + with patch("swarms.models.timm.create_model") as mock_create_model: model_name = "resnet18" pretrained = True in_chans = 3 timm_model = TimmModel(model_name, pretrained, in_chans) task = torch.rand(1, in_chans, 224, 224) result = timm_model(task) - mock_create_model.assert_called_once_with( - model_name, pretrained=pretrained, in_chans=in_chans - ) + mock_create_model.assert_called_once_with(model_name, + pretrained=pretrained, + in_chans=in_chans) assert result == mock_create_model.return_value(task) diff --git a/tests/models/test_timm_model.py b/tests/models/test_timm_model.py index b2f8f6c9..06ec2a93 100644 --- a/tests/models/test_timm_model.py +++ b/tests/models/test_timm_model.py @@ -22,17 +22,14 @@ def test_create_model(sample_model_info): def test_call(sample_model_info): model_handler = TimmModel() input_tensor = torch.randn(1, 3, 224, 224) - output_shape = model_handler.__call__( - sample_model_info, 
input_tensor - ) + output_shape = model_handler.__call__(sample_model_info, input_tensor) assert isinstance(output_shape, torch.Size) def test_get_supported_models_mock(): model_handler = TimmModel() model_handler._get_supported_models = Mock( - return_value=["resnet18", "resnet50"] - ) + return_value=["resnet18", "resnet50"]) supported_models = model_handler._get_supported_models() assert supported_models == ["resnet18", "resnet50"] diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py index dd2a2f89..4b14acc3 100644 --- a/tests/models/test_togther.py +++ b/tests/models/test_togther.py @@ -55,7 +55,11 @@ def test_init_custom_params(mock_api_key): def test_run_success(mock_post, mock_api_key): mock_response = Mock() mock_response.json.return_value = { - "choices": [{"message": {"content": "Generated response"}}] + "choices": [{ + "message": { + "content": "Generated response" + } + }] } mock_post.return_value = mock_response @@ -69,8 +73,7 @@ def test_run_success(mock_post, mock_api_key): @patch("swarms.models.together_model.requests.post") def test_run_failure(mock_post, mock_api_key): mock_post.side_effect = requests.exceptions.RequestException( - "Request failed" - ) + "Request failed") model = TogetherLLM() task = "What is the color of the object?" @@ -89,9 +92,7 @@ def test_run_with_logging_enabled(caplog, mock_api_key): assert "Sending request to" in caplog.text -@pytest.mark.parametrize( - "invalid_input", [None, 123, ["list", "of", "items"]] -) +@pytest.mark.parametrize("invalid_input", [None, 123, ["list", "of", "items"]]) def test_invalid_task_input(invalid_input, mock_api_key): model = TogetherLLM() response = model.run(invalid_input) @@ -103,7 +104,11 @@ def test_invalid_task_input(invalid_input, mock_api_key): def test_run_streaming_enabled(mock_post, mock_api_key): mock_response = Mock() mock_response.json.return_value = { - "choices": [{"message": {"content": "Generated response"}}] + "choices": [{ + "message": { + "content": "Generated response" + } + }] } mock_post.return_value = mock_response diff --git a/tests/models/test_ultralytics.py b/tests/models/test_ultralytics.py index cca1d023..30aa64fb 100644 --- a/tests/models/test_ultralytics.py +++ b/tests/models/test_ultralytics.py @@ -20,9 +20,7 @@ def test_ultralytics_call(): args = (1, 2, 3) kwargs = {"a": "A", "b": "B"} result = ultralytics(task, *args, **kwargs) - mock_yolo.return_value.assert_called_once_with( - task, *args, **kwargs - ) + mock_yolo.return_value.assert_called_once_with(task, *args, **kwargs) assert result == mock_yolo.return_value.return_value diff --git a/tests/models/test_vilt.py b/tests/models/test_vilt.py index d849f98e..dd1e47e2 100644 --- a/tests/models/test_vilt.py +++ b/tests/models/test_vilt.py @@ -21,17 +21,13 @@ def test_vilt_initialization(vilt_instance): # 2. 
Test Model Predictions @patch.object(requests, "get") @patch.object(Image, "open") -def test_vilt_prediction( - mock_image_open, mock_requests_get, vilt_instance -): +def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): mock_image = Mock() mock_image_open.return_value = mock_image mock_requests_get.return_value.raw = Mock() # It's a mock response, so no real answer expected - with pytest.raises( - Exception - ): # Ensure exception is more specific + with pytest.raises(Exception): # Ensure exception is more specific vilt_instance( "What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", @@ -66,9 +62,7 @@ def test_vilt_network_exception(vilt_instance): ], ) def test_vilt_various_inputs(text, image_url, vilt_instance): - with pytest.raises( - Exception - ): # Again, ensure exception is more specific + with pytest.raises(Exception): # Again, ensure exception is more specific vilt_instance(text, image_url) diff --git a/tests/models/test_yi_200k.py b/tests/models/test_yi_200k.py index b31daa3e..9b2741a2 100644 --- a/tests/models/test_yi_200k.py +++ b/tests/models/test_yi_200k.py @@ -32,9 +32,7 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length): @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_yi34b_generate_text_with_temperature( - yi34b_model, temperature -): +def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): prompt = "There's a place where time stands still." generated_text = yi34b_model(prompt, temperature=temperature) assert isinstance(generated_text, str) @@ -42,27 +40,24 @@ def test_yi34b_generate_text_with_temperature( def test_yi34b_generate_text_with_invalid_prompt(yi34b_model): prompt = None # Invalid prompt - with pytest.raises( - ValueError, match="Input prompt must be a non-empty string" - ): + with pytest.raises(ValueError, + match="Input prompt must be a non-empty string"): yi34b_model(prompt) def test_yi34b_generate_text_with_invalid_max_length(yi34b_model): prompt = "There's a place where time stands still." max_length = -1 # Invalid max_length - with pytest.raises( - ValueError, match="max_length must be a positive integer" - ): + with pytest.raises(ValueError, + match="max_length must be a positive integer"): yi34b_model(prompt, max_length=max_length) def test_yi34b_generate_text_with_invalid_temperature(yi34b_model): prompt = "There's a place where time stands still." temperature = 2.0 # Invalid temperature - with pytest.raises( - ValueError, match="temperature must be between 0.01 and 1.0" - ): + with pytest.raises(ValueError, + match="temperature must be between 0.01 and 1.0"): yi34b_model(prompt, temperature=temperature) @@ -83,40 +78,32 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p): def test_yi34b_generate_text_with_invalid_top_k(yi34b_model): prompt = "There's a place where time stands still." top_k = -1 # Invalid top_k - with pytest.raises( - ValueError, match="top_k must be a non-negative integer" - ): + with pytest.raises(ValueError, + match="top_k must be a non-negative integer"): yi34b_model(prompt, top_k=top_k) def test_yi34b_generate_text_with_invalid_top_p(yi34b_model): prompt = "There's a place where time stands still." 
top_p = 1.5 # Invalid top_p - with pytest.raises( - ValueError, match="top_p must be between 0.0 and 1.0" - ): + with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"): yi34b_model(prompt, top_p=top_p) @pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5]) -def test_yi34b_generate_text_with_repitition_penalty( - yi34b_model, repitition_penalty -): +def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, + repitition_penalty): prompt = "There's a place where time stands still." - generated_text = yi34b_model( - prompt, repitition_penalty=repitition_penalty - ) + generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty) assert isinstance(generated_text, str) -def test_yi34b_generate_text_with_invalid_repitition_penalty( - yi34b_model, -): +def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model,): prompt = "There's a place where time stands still." repitition_penalty = 0.0 # Invalid repitition_penalty with pytest.raises( - ValueError, - match="repitition_penalty must be a positive float", + ValueError, + match="repitition_penalty must be a positive float", ): yi34b_model(prompt, repitition_penalty=repitition_penalty) diff --git a/tests/models/test_zeroscope.py b/tests/models/test_zeroscope.py index 25a4c597..f0b130d1 100644 --- a/tests/models/test_zeroscope.py +++ b/tests/models/test_zeroscope.py @@ -25,16 +25,11 @@ def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) + mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) + mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") mock_pipeline_instance.enable_vae_slicing.assert_called_once() mock_pipeline_instance.enable_forward_chunking.assert_called_once_with( - chunk_size=1, dim=1 - ) + chunk_size=1, dim=1) result = zeroscope.forward("Test task") assert result == "Generated frames" mock_pipeline_instance.assert_called_once_with( @@ -51,12 +46,8 @@ def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) + mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) + mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") mock_pipeline_instance.side_effect = Exception("Test error") with pytest.raises(Exception, match="Test error"): zeroscope.forward("Test task") @@ -67,12 +58,8 @@ def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) + mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) + mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") result = zeroscope.__call__("Test task") assert result == "Generated frames" mock_pipeline_instance.assert_called_once_with( @@ -89,12 +76,8 @@ def 
test_zeroscope_ttv_call(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) + mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) + mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") mock_pipeline_instance.side_effect = Exception("Test error") with pytest.raises(Exception, match="Test error"): zeroscope.__call__("Test task") @@ -105,12 +88,8 @@ def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) + mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) + mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") result = zeroscope.save_video_path("Test video path") assert result == "Test video path" mock_pipeline_instance.assert_called_once_with( diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py index 5be7f31a..bfb3b365 100644 --- a/tests/structs/test_agent.py +++ b/tests/structs/test_agent.py @@ -18,9 +18,7 @@ openai_api_key = os.getenv("OPENAI_API_KEY") # Mocks and Fixtures @pytest.fixture def mocked_llm(): - return OpenAIChat( - openai_api_key=openai_api_key, - ) + return OpenAIChat(openai_api_key=openai_api_key,) @pytest.fixture @@ -65,15 +63,12 @@ def test_provide_feedback(basic_flow): @patch("time.sleep", return_value=None) # to speed up tests def test_run_without_stopping_condition(mocked_sleep, basic_flow): response = basic_flow.run("Test task") - assert ( - response == "Test task" - ) # since our mocked llm doesn't modify the response + assert (response == "Test task" + ) # since our mocked llm doesn't modify the response @patch("time.sleep", return_value=None) # to speed up tests -def test_run_with_stopping_condition( - mocked_sleep, flow_with_condition -): +def test_run_with_stopping_condition(mocked_sleep, flow_with_condition): response = flow_with_condition.run("Stop") assert response == "Stop" @@ -113,6 +108,7 @@ def test_env_variable_handling(monkeypatch): # Test initializing the agent with different stopping conditions def test_flow_with_custom_stopping_condition(mocked_llm): + def stopping_condition(x): return "terminate" in x.lower() @@ -133,9 +129,7 @@ def test_flow_call(basic_flow): # Test formatting the prompt def test_format_prompt(basic_flow): - formatted_prompt = basic_flow.format_prompt( - "Hello {name}", name="John" - ) + formatted_prompt = basic_flow.format_prompt("Hello {name}", name="John") assert formatted_prompt == "Hello John" @@ -164,9 +158,15 @@ def test_interactive_mode(basic_flow): # Test bulk run with varied inputs def test_bulk_run_varied_inputs(basic_flow): inputs = [ - {"task": "Test1"}, - {"task": "Test2"}, - {"task": "Stop now"}, + { + "task": "Test1" + }, + { + "task": "Test2" + }, + { + "task": "Stop now" + }, ] responses = basic_flow.bulk_run(inputs) assert responses == ["Test1", "Test2", "Stop now"] @@ -191,12 +191,9 @@ def test_save_different_memory(basic_flow, tmp_path): # Test the stopping condition check def test_check_stopping_condition(flow_with_condition): - assert 
flow_with_condition._check_stopping_condition( - "Stop this process" - ) + assert flow_with_condition._check_stopping_condition("Stop this process") assert not flow_with_condition._check_stopping_condition( - "Continue the task" - ) + "Continue the task") # Test without providing max loops (default value should be 5) @@ -252,9 +249,7 @@ def test_different_retry_intervals(mocked_sleep, basic_flow): # Test invoking the agent with additional kwargs @patch("time.sleep", return_value=None) def test_flow_call_with_kwargs(mocked_sleep, basic_flow): - response = basic_flow( - "Test call", param1="value1", param2="value2" - ) + response = basic_flow("Test call", param1="value1", param2="value2") assert response == "Test call" @@ -289,9 +284,7 @@ def test_stopping_token_in_response(mocked_sleep, basic_flow): def flow_instance(): # Create an instance of the Agent class with required parameters for testing # You may need to adjust this based on your actual class initialization - llm = OpenAIChat( - openai_api_key=openai_api_key, - ) + llm = OpenAIChat(openai_api_key=openai_api_key,) agent = Agent( llm=llm, max_loops=5, @@ -338,9 +331,7 @@ def test_flow_autosave(flow_instance): def test_flow_response_filtering(flow_instance): # Test the response filtering functionality flow_instance.add_response_filter("filter_this") - response = flow_instance.filtered_run( - "This message should filter_this" - ) + response = flow_instance.filtered_run("This message should filter_this") assert "filter_this" not in response @@ -400,11 +391,8 @@ def test_flow_response_length(flow_instance): # Test checking the length of the response response = flow_instance.run( "Generate a 10,000 word long blog on mental clarity and the" - " benefits of meditation." - ) - assert ( - len(response) > flow_instance.get_response_length_threshold() - ) + " benefits of meditation.") + assert (len(response) > flow_instance.get_response_length_threshold()) def test_flow_set_response_length_threshold(flow_instance): @@ -493,9 +481,7 @@ def test_flow_get_conversation_log(flow_instance): flow_instance.run("Message 1") flow_instance.run("Message 2") conversation_log = flow_instance.get_conversation_log() - assert ( - len(conversation_log) == 4 - ) # Including system and user messages + assert (len(conversation_log) == 4) # Including system and user messages def test_flow_clear_conversation_log(flow_instance): @@ -579,37 +565,18 @@ def test_flow_rollback(flow_instance): flow_instance.change_prompt("New prompt") flow_instance.get_state() flow_instance.rollback_to_state(state1) - assert ( - flow_instance.get_current_prompt() == state1["current_prompt"] - ) + assert (flow_instance.get_current_prompt() == state1["current_prompt"]) assert flow_instance.get_instructions() == state1["instructions"] - assert ( - flow_instance.get_user_messages() == state1["user_messages"] - ) - assert ( - flow_instance.get_response_history() - == state1["response_history"] - ) - assert ( - flow_instance.get_conversation_log() - == state1["conversation_log"] - ) - assert ( - flow_instance.is_dynamic_pacing_enabled() - == state1["dynamic_pacing_enabled"] - ) - assert ( - flow_instance.get_response_length_threshold() - == state1["response_length_threshold"] - ) - assert ( - flow_instance.get_response_filters() - == state1["response_filters"] - ) + assert (flow_instance.get_user_messages() == state1["user_messages"]) + assert (flow_instance.get_response_history() == state1["response_history"]) + assert (flow_instance.get_conversation_log() == state1["conversation_log"]) + 
assert (flow_instance.is_dynamic_pacing_enabled() == + state1["dynamic_pacing_enabled"]) + assert (flow_instance.get_response_length_threshold() == + state1["response_length_threshold"]) + assert (flow_instance.get_response_filters() == state1["response_filters"]) assert flow_instance.get_max_loops() == state1["max_loops"] - assert ( - flow_instance.get_autosave_path() == state1["autosave_path"] - ) + assert (flow_instance.get_autosave_path() == state1["autosave_path"]) assert flow_instance.get_state() == state1 @@ -618,8 +585,7 @@ def test_flow_contextual_intent(flow_instance): flow_instance.add_context("location", "New York") flow_instance.add_context("time", "tomorrow") response = flow_instance.run( - "What's the weather like in {location} at {time}?" - ) + "What's the weather like in {location} at {time}?") assert "New York" in response assert "tomorrow" in response @@ -627,13 +593,9 @@ def test_flow_contextual_intent(flow_instance): def test_flow_contextual_intent_override(flow_instance): # Test contextual intent override flow_instance.add_context("location", "New York") - response1 = flow_instance.run( - "What's the weather like in {location}?" - ) + response1 = flow_instance.run("What's the weather like in {location}?") flow_instance.add_context("location", "Los Angeles") - response2 = flow_instance.run( - "What's the weather like in {location}?" - ) + response2 = flow_instance.run("What's the weather like in {location}?") assert "New York" in response1 assert "Los Angeles" in response2 @@ -641,13 +603,9 @@ def test_flow_contextual_intent_override(flow_instance): def test_flow_contextual_intent_reset(flow_instance): # Test resetting contextual intent flow_instance.add_context("location", "New York") - response1 = flow_instance.run( - "What's the weather like in {location}?" - ) + response1 = flow_instance.run("What's the weather like in {location}?") flow_instance.reset_context() - response2 = flow_instance.run( - "What's the weather like in {location}?" - ) + response2 = flow_instance.run("What's the weather like in {location}?") assert "New York" in response1 assert "New York" in response2 @@ -672,9 +630,7 @@ def test_flow_non_interruptible(flow_instance): def test_flow_timeout(flow_instance): # Test conversation timeout flow_instance.timeout = 60 # Set a timeout of 60 seconds - response = flow_instance.run( - "This should take some time to respond." 
- ) + response = flow_instance.run("This should take some time to respond.") assert "Timed out" in response assert flow_instance.is_timed_out() is True @@ -723,20 +679,14 @@ def test_flow_save_and_load_conversation(flow_instance): def test_flow_inject_custom_system_message(flow_instance): # Test injecting a custom system message into the conversation - flow_instance.inject_custom_system_message( - "Custom system message" - ) - assert ( - "Custom system message" in flow_instance.get_message_history() - ) + flow_instance.inject_custom_system_message("Custom system message") + assert ("Custom system message" in flow_instance.get_message_history()) def test_flow_inject_custom_user_message(flow_instance): # Test injecting a custom user message into the conversation flow_instance.inject_custom_user_message("Custom user message") - assert ( - "Custom user message" in flow_instance.get_message_history() - ) + assert ("Custom user message" in flow_instance.get_message_history()) def test_flow_inject_custom_response(flow_instance): @@ -747,45 +697,28 @@ def test_flow_inject_custom_response(flow_instance): def test_flow_clear_injected_messages(flow_instance): # Test clearing injected messages from the conversation - flow_instance.inject_custom_system_message( - "Custom system message" - ) + flow_instance.inject_custom_system_message("Custom system message") flow_instance.inject_custom_user_message("Custom user message") flow_instance.inject_custom_response("Custom response") flow_instance.clear_injected_messages() - assert ( - "Custom system message" - not in flow_instance.get_message_history() - ) - assert ( - "Custom user message" - not in flow_instance.get_message_history() - ) - assert ( - "Custom response" not in flow_instance.get_message_history() - ) + assert ("Custom system message" not in flow_instance.get_message_history()) + assert ("Custom user message" not in flow_instance.get_message_history()) + assert ("Custom response" not in flow_instance.get_message_history()) def test_flow_disable_message_history(flow_instance): # Test disabling message history recording flow_instance.disable_message_history() response = flow_instance.run( - "This message should not be recorded in history." - ) - assert ( - "This message should not be recorded in history." in response - ) - assert ( - len(flow_instance.get_message_history()) == 0 - ) # History is empty + "This message should not be recorded in history.") + assert ("This message should not be recorded in history." in response) + assert (len(flow_instance.get_message_history()) == 0) # History is empty def test_flow_enable_message_history(flow_instance): # Test enabling message history recording flow_instance.enable_message_history() - response = flow_instance.run( - "This message should be recorded in history." - ) + response = flow_instance.run("This message should be recorded in history.") assert "This message should be recorded in history." 
in response assert len(flow_instance.get_message_history()) == 1 @@ -795,9 +728,7 @@ def test_flow_custom_logger(flow_instance): custom_logger = logger # Replace with your custom logger class flow_instance.set_logger(custom_logger) response = flow_instance.run("Custom logger test") - assert ( - "Logged using custom logger" in response - ) # Verify logging message + assert ("Logged using custom logger" in response) # Verify logging message def test_flow_batch_processing(flow_instance): @@ -871,43 +802,35 @@ def test_flow_input_validation(flow_instance): with pytest.raises(ValueError): flow_instance.set_message_delimiter( - "" - ) # Empty delimiter, should raise ValueError + "") # Empty delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - None - ) # None delimiter, should raise ValueError + None) # None delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - 123 - ) # Invalid delimiter type, should raise ValueError + 123) # Invalid delimiter type, should raise ValueError with pytest.raises(ValueError): flow_instance.set_logger( - "invalid_logger" - ) # Invalid logger type, should raise ValueError + "invalid_logger") # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context( - None, "value" - ) # None key, should raise ValueError + flow_instance.add_context(None, + "value") # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context( - "key", None - ) # None value, should raise ValueError + flow_instance.add_context("key", + None) # None value, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - None, "value" - ) # None key, should raise ValueError + None, "value") # None key, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - "key", None - ) # None value, should raise ValueError + "key", None) # None value, should raise ValueError def test_flow_conversation_reset(flow_instance): @@ -934,6 +857,7 @@ def test_flow_conversation_persistence(flow_instance): def test_flow_custom_event_listener(flow_instance): # Test custom event listener class CustomEventListener: + def on_message_received(self, message): pass @@ -945,10 +869,10 @@ def test_flow_custom_event_listener(flow_instance): # Ensure that the custom event listener methods are called during a conversation with mock.patch.object( - custom_event_listener, "on_message_received" - ) as mock_received, mock.patch.object( - custom_event_listener, "on_response_generated" - ) as mock_response: + custom_event_listener, + "on_message_received") as mock_received, mock.patch.object( + custom_event_listener, + "on_response_generated") as mock_response: flow_instance.run("Message 1") mock_received.assert_called_once() mock_response.assert_called_once() @@ -957,6 +881,7 @@ def test_flow_custom_event_listener(flow_instance): def test_flow_multiple_event_listeners(flow_instance): # Test multiple event listeners class FirstEventListener: + def on_message_received(self, message): pass @@ -964,6 +889,7 @@ def test_flow_multiple_event_listeners(flow_instance): pass class SecondEventListener: + def on_message_received(self, message): pass @@ -977,14 +903,14 @@ def test_flow_multiple_event_listeners(flow_instance): # Ensure that both event listeners receive events during a conversation with mock.patch.object( - first_event_listener, "on_message_received" - ) as mock_first_received, mock.patch.object( - 
first_event_listener, "on_response_generated" - ) as mock_first_response, mock.patch.object( - second_event_listener, "on_message_received" - ) as mock_second_received, mock.patch.object( - second_event_listener, "on_response_generated" - ) as mock_second_response: + first_event_listener, + "on_message_received") as mock_first_received, mock.patch.object( + first_event_listener, "on_response_generated" + ) as mock_first_response, mock.patch.object( + second_event_listener, "on_message_received" + ) as mock_second_received, mock.patch.object( + second_event_listener, + "on_response_generated") as mock_second_response: flow_instance.run("Message 1") mock_first_received.assert_called_once() mock_first_response.assert_called_once() @@ -997,38 +923,31 @@ def test_flow_error_handling(flow_instance): # Test error handling and exceptions with pytest.raises(ValueError): flow_instance.set_message_delimiter( - "" - ) # Empty delimiter, should raise ValueError + "") # Empty delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - None - ) # None delimiter, should raise ValueError + None) # None delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_logger( - "invalid_logger" - ) # Invalid logger type, should raise ValueError + "invalid_logger") # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context( - None, "value" - ) # None key, should raise ValueError + flow_instance.add_context(None, + "value") # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context( - "key", None - ) # None value, should raise ValueError + flow_instance.add_context("key", + None) # None value, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - None, "value" - ) # None key, should raise ValueError + None, "value") # None key, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - "key", None - ) # None value, should raise ValueError + "key", None) # None value, should raise ValueError def test_flow_context_operations(flow_instance): @@ -1065,14 +984,8 @@ def test_flow_custom_response(flow_instance): flow_instance.set_response_generator(custom_response_generator) assert flow_instance.run("Hello") == "Hi there!" - assert ( - flow_instance.run("How are you?") - == "I'm doing well, thank you." - ) - assert ( - flow_instance.run("What's your name?") - == "I don't understand." - ) + assert (flow_instance.run("How are you?") == "I'm doing well, thank you.") + assert (flow_instance.run("What's your name?") == "I don't understand.") def test_flow_message_validation(flow_instance): @@ -1083,12 +996,8 @@ def test_flow_message_validation(flow_instance): flow_instance.set_message_validator(custom_message_validator) assert flow_instance.run("Valid message") is not None - assert ( - flow_instance.run("") is None - ) # Empty message should be rejected - assert ( - flow_instance.run(None) is None - ) # None message should be rejected + assert (flow_instance.run("") is None) # Empty message should be rejected + assert (flow_instance.run(None) is None) # None message should be rejected def test_flow_custom_logging(flow_instance): @@ -1113,15 +1022,10 @@ def test_flow_complex_use_case(flow_instance): flow_instance.add_context("user_id", "12345") flow_instance.run("Hello") flow_instance.run("How can I help you?") - assert ( - flow_instance.get_response() == "Please provide more details." 
- ) + assert (flow_instance.get_response() == "Please provide more details.") flow_instance.update_context("user_id", "54321") flow_instance.run("I need help with my order") - assert ( - flow_instance.get_response() - == "Sure, I can assist with that." - ) + assert (flow_instance.get_response() == "Sure, I can assist with that.") flow_instance.reset_conversation() assert len(flow_instance.get_message_history()) == 0 assert flow_instance.get_context("user_id") is None @@ -1160,9 +1064,7 @@ def test_flow_concurrent_requests(flow_instance): def test_flow_custom_timeout(flow_instance): # Test custom timeout handling - flow_instance.set_timeout( - 10 - ) # Set a custom timeout of 10 seconds + flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds assert flow_instance.get_timeout() == 10 import time @@ -1213,16 +1115,10 @@ def test_flow_agent_history_prompt(flow_instance): history = ["User: Hi", "AI: Hello"] agent_history_prompt = flow_instance.agent_history_prompt( - system_prompt, history - ) + system_prompt, history) - assert ( - "SYSTEM_PROMPT: This is the system prompt." - in agent_history_prompt - ) - assert ( - "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt - ) + assert ("SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt) + assert ("History: ['User: Hi', 'AI: Hello']" in agent_history_prompt) async def test_flow_run_concurrent(flow_instance): @@ -1237,9 +1133,18 @@ async def test_flow_run_concurrent(flow_instance): def test_flow_bulk_run(flow_instance): # Test bulk running of tasks input_data = [ - {"task": "Task 1", "param1": "value1"}, - {"task": "Task 2", "param2": "value2"}, - {"task": "Task 3", "param3": "value3"}, + { + "task": "Task 1", + "param1": "value1" + }, + { + "task": "Task 2", + "param2": "value2" + }, + { + "task": "Task 3", + "param3": "value3" + }, ] responses = flow_instance.bulk_run(input_data) @@ -1254,9 +1159,7 @@ def test_flow_from_llm_and_template(): llm_instance = mocked_llm # Replace with your LLM class template = "This is a template for testing." - flow_instance = Agent.from_llm_and_template( - llm_instance, template - ) + flow_instance = Agent.from_llm_and_template(llm_instance, template) assert isinstance(flow_instance, Agent) @@ -1265,12 +1168,10 @@ def test_flow_from_llm_and_template_file(): # Test creating Agent instance from an LLM and a template file llm_instance = mocked_llm # Replace with your LLM class template_file = ( # Create a template file for testing - "template.txt" - ) + "template.txt") - flow_instance = Agent.from_llm_and_template_file( - llm_instance, template_file - ) + flow_instance = Agent.from_llm_and_template_file(llm_instance, + template_file) assert isinstance(flow_instance, Agent) diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py index 2e5585bf..b0d00606 100644 --- a/tests/structs/test_autoscaler.py +++ b/tests/structs/test_autoscaler.py @@ -44,9 +44,7 @@ def test_autoscaler_run(): agent.id, "Generate a 10,000 word blog on health and wellness.", ) - assert ( - out == "Generate a 10,000 word blog on health and wellness." 
- ) + assert (out == "Generate a 10,000 word blog on health and wellness.") def test_autoscaler_add_agent(): @@ -239,9 +237,7 @@ def test_autoscaler_add_task(): def test_autoscaler_scale_up(): - autoscaler = AutoScaler( - initial_agents=5, scale_up_factor=2, agent=agent - ) + autoscaler = AutoScaler(initial_agents=5, scale_up_factor=2, agent=agent) autoscaler.scale_up() assert len(autoscaler.agents_pool) == 10 diff --git a/tests/structs/test_base.py b/tests/structs/test_base.py index 971f966b..3cbe5c3d 100644 --- a/tests/structs/test_base.py +++ b/tests/structs/test_base.py @@ -7,6 +7,7 @@ from swarms.structs.base import BaseStructure class TestBaseStructure: + def test_init(self): base_structure = BaseStructure( name="TestStructure", @@ -88,11 +89,8 @@ class TestBaseStructure: with open(log_file) as file: lines = file.readlines() assert len(lines) == 1 - assert ( - lines[0] - == f"[{base_structure._current_timestamp()}]" - f" [{event_type}] {event}\n" - ) + assert (lines[0] == f"[{base_structure._current_timestamp()}]" + f" [{event_type}] {event}\n") @pytest.mark.asyncio async def test_run_async(self): @@ -136,9 +134,7 @@ class TestBaseStructure: artifact = {"key": "value"} artifact_name = "test_artifact" - await base_structure.save_artifact_async( - artifact, artifact_name - ) + await base_structure.save_artifact_async(artifact, artifact_name) loaded_artifact = base_structure.load_artifact(artifact_name) assert loaded_artifact == artifact @@ -151,9 +147,8 @@ class TestBaseStructure: artifact = {"key": "value"} artifact_name = "test_artifact" base_structure.save_artifact(artifact, artifact_name) - loaded_artifact = await base_structure.load_artifact_async( - artifact_name - ) + loaded_artifact = await base_structure.load_artifact_async(artifact_name + ) assert loaded_artifact == artifact @@ -170,11 +165,8 @@ class TestBaseStructure: with open(log_file) as file: lines = file.readlines() assert len(lines) == 1 - assert ( - lines[0] - == f"[{base_structure._current_timestamp()}]" - f" [{event_type}] {event}\n" - ) + assert (lines[0] == f"[{base_structure._current_timestamp()}]" + f" [{event_type}] {event}\n") @pytest.mark.asyncio async def test_asave_to_file(self, tmpdir): @@ -201,18 +193,14 @@ class TestBaseStructure: def test_run_in_thread(self): base_structure = BaseStructure() - result = base_structure.run_in_thread( - lambda: "Thread Test Result" - ) + result = base_structure.run_in_thread(lambda: "Thread Test Result") assert result.result() == "Thread Test Result" def test_save_and_decompress_data(self): base_structure = BaseStructure() data = {"key": "value"} compressed_data = base_structure.compress_data(data) - decompressed_data = base_structure.decompres_data( - compressed_data - ) + decompressed_data = base_structure.decompres_data(compressed_data) assert decompressed_data == data def test_run_batched(self): @@ -222,13 +210,11 @@ class TestBaseStructure: return f"Processed {data}" batched_data = list(range(10)) - result = base_structure.run_batched( - batched_data, batch_size=5, func=run_function - ) + result = base_structure.run_batched(batched_data, + batch_size=5, + func=run_function) - expected_result = [ - f"Processed {data}" for data in batched_data - ] + expected_result = [f"Processed {data}" for data in batched_data] assert result == expected_result def test_load_config(self, tmpdir): @@ -246,15 +232,12 @@ class TestBaseStructure: tmp_dir = tmpdir.mkdir("test_dir") base_structure = BaseStructure() data_to_backup = {"key": "value"} - base_structure.backup_data( - 
data_to_backup, backup_path=tmp_dir - ) + base_structure.backup_data(data_to_backup, backup_path=tmp_dir) backup_files = os.listdir(tmp_dir) assert len(backup_files) == 1 loaded_data = base_structure.load_from_file( - os.path.join(tmp_dir, backup_files[0]) - ) + os.path.join(tmp_dir, backup_files[0])) assert loaded_data == data_to_backup def test_monitor_resources(self): @@ -279,11 +262,9 @@ class TestBaseStructure: return f"Processed {data}" batched_data = list(range(10)) - result = base_structure.run_with_resources_batched( - batched_data, batch_size=5, func=run_function - ) + result = base_structure.run_with_resources_batched(batched_data, + batch_size=5, + func=run_function) - expected_result = [ - f"Processed {data}" for data in batched_data - ] + expected_result = [f"Processed {data}" for data in batched_data] assert result == expected_result diff --git a/tests/structs/test_base_workflow.py b/tests/structs/test_base_workflow.py index ccb7a563..2d94caf2 100644 --- a/tests/structs/test_base_workflow.py +++ b/tests/structs/test_base_workflow.py @@ -30,13 +30,8 @@ def test_load_workflow_state(): workflow.load_workflow_state("workflow_state.json") assert workflow.max_loops == 1 assert len(workflow.tasks) == 2 - assert ( - workflow.tasks[0].description == "What's the weather in miami" - ) - assert ( - workflow.tasks[1].description - == "Create a report on these metrics" - ) + assert (workflow.tasks[0].description == "What's the weather in miami") + assert (workflow.tasks[1].description == "Create a report on these metrics") teardown_workflow() diff --git a/tests/structs/test_concurrent_workflow.py b/tests/structs/test_concurrent_workflow.py index e3fabdd5..9a3f46da 100644 --- a/tests/structs/test_concurrent_workflow.py +++ b/tests/structs/test_concurrent_workflow.py @@ -18,9 +18,7 @@ def test_run(): workflow.add(task1) workflow.add(task2) - with patch( - "concurrent.futures.ThreadPoolExecutor" - ) as mock_executor: + with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor: future1 = Future() future1.set_result(None) future2 = Future() diff --git a/tests/structs/test_conversation.py b/tests/structs/test_conversation.py index 049f3fb3..67f083f3 100644 --- a/tests/structs/test_conversation.py +++ b/tests/structs/test_conversation.py @@ -87,16 +87,13 @@ def test_return_history_as_string_with_different_roles(role, content): @pytest.mark.parametrize("message_count", range(1, 11)) -def test_return_history_as_string_with_multiple_messages( - message_count, -): +def test_return_history_as_string_with_multiple_messages(message_count,): conv = Conversation() for i in range(message_count): conv.add("user", f"Message {i + 1}") result = conv.return_history_as_string() expected = "".join( - [f"user: Message {i + 1}\n\n" for i in range(message_count)] - ) + [f"user: Message {i + 1}\n\n" for i in range(message_count)]) assert result == expected @@ -122,10 +119,8 @@ def test_return_history_as_string_with_large_message(conversation): large_message = "Hello, world! 
" * 10000 # 10,000 repetitions conversation.add("user", large_message) result = conversation.return_history_as_string() - expected = ( - "user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" - f" {large_message}\n\n" - ) + expected = ("user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" + f" {large_message}\n\n") assert result == expected @@ -141,10 +136,8 @@ def test_export_import_conversation(conversation, tmp_path): conversation.export_conversation(filename) new_conversation = Conversation() new_conversation.import_conversation(filename) - assert ( - new_conversation.return_history_as_string() - == conversation.return_history_as_string() - ) + assert (new_conversation.return_history_as_string() == + conversation.return_history_as_string()) def test_count_messages_by_role(conversation): diff --git a/tests/structs/test_groupchat.py b/tests/structs/test_groupchat.py index e8096d9c..d9635e79 100644 --- a/tests/structs/test_groupchat.py +++ b/tests/structs/test_groupchat.py @@ -11,6 +11,7 @@ llm2 = Anthropic() # Mock the OpenAI class for testing class MockOpenAI: + def __init__(self, *args, **kwargs): pass @@ -139,9 +140,9 @@ def test_groupchat_manager_generate_reply(): selector = agent1 # Initialize GroupChatManager - manager = GroupChatManager( - groupchat=groupchat, selector=selector, openai=mocked_openai - ) + manager = GroupChatManager(groupchat=groupchat, + selector=selector, + openai=mocked_openai) # Generate a reply task = "Write me a riddle" @@ -165,9 +166,8 @@ def test_groupchat_select_speaker(): # Simulate selecting the next speaker last_speaker = agent1 - next_speaker = manager.select_speaker( - last_speaker=last_speaker, selector=selector - ) + next_speaker = manager.select_speaker(last_speaker=last_speaker, + selector=selector) # Ensure the next speaker is agent2 assert next_speaker == agent2 @@ -185,9 +185,8 @@ def test_groupchat_underpopulated_group(): # Simulate selecting the next speaker in an underpopulated group last_speaker = agent1 - next_speaker = manager.select_speaker( - last_speaker=last_speaker, selector=selector - ) + next_speaker = manager.select_speaker(last_speaker=last_speaker, + selector=selector) # Ensure the next speaker is the same as the last speaker in an underpopulated group assert next_speaker == last_speaker @@ -205,15 +204,13 @@ def test_groupchat_max_rounds(): # Simulate the conversation with max rounds last_speaker = agent1 for _ in range(2): - next_speaker = manager.select_speaker( - last_speaker=last_speaker, selector=selector - ) + next_speaker = manager.select_speaker(last_speaker=last_speaker, + selector=selector) last_speaker = next_speaker # Try one more round, should stay with the last speaker - next_speaker = manager.select_speaker( - last_speaker=last_speaker, selector=selector - ) + next_speaker = manager.select_speaker(last_speaker=last_speaker, + selector=selector) # Ensure the next speaker is the same as the last speaker after reaching max rounds assert next_speaker == last_speaker diff --git a/tests/structs/test_json.py b/tests/structs/test_json.py index 9ba11072..c0e1e851 100644 --- a/tests/structs/test_json.py +++ b/tests/structs/test_json.py @@ -15,10 +15,8 @@ def valid_schema_path(tmp_path): d = tmp_path / "sub" d.mkdir() p = d / "schema.json" - p.write_text( - '{"type": "object", "properties": {"name": {"type":' - ' "string"}}}' - ) + p.write_text('{"type": "object", "properties": {"name": {"type":' + ' "string"}}}') return str(p) @@ -33,6 +31,7 @@ def invalid_schema_path(tmp_path): # This test class must be subclassed 
as JSON class is abstract class TestableJSON(JSON): + def validate(self, data): # Here must be a real validation implementation for testing pass diff --git a/tests/structs/test_majority_voting.py b/tests/structs/test_majority_voting.py index dcd25f0b..b6f09020 100644 --- a/tests/structs/test_majority_voting.py +++ b/tests/structs/test_majority_voting.py @@ -35,15 +35,9 @@ def test_majority_voting_run_concurrent(mocker): majority_vote = mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with( - "What is the capital of France?" - ) - agent2.run.assert_called_once_with( - "What is the capital of France?" - ) - agent3.run.assert_called_once_with( - "What is the capital of France?" - ) + agent1.run.assert_called_once_with("What is the capital of France?") + agent2.run.assert_called_once_with("What is the capital of France?") + agent3.run.assert_called_once_with("What is the capital of France?") # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) @@ -83,15 +77,9 @@ def test_majority_voting_run_multithreaded(mocker): majority_vote = mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with( - "What is the capital of France?" - ) - agent2.run.assert_called_once_with( - "What is the capital of France?" - ) - agent3.run.assert_called_once_with( - "What is the capital of France?" - ) + agent1.run.assert_called_once_with("What is the capital of France?") + agent2.run.assert_called_once_with("What is the capital of France?") + agent3.run.assert_called_once_with("What is the capital of France?") # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) @@ -133,15 +121,9 @@ async def test_majority_voting_run_asynchronous(mocker): majority_vote = await mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with( - "What is the capital of France?" - ) - agent2.run.assert_called_once_with( - "What is the capital of France?" - ) - agent3.run.assert_called_once_with( - "What is the capital of France?" 
- ) + agent1.run.assert_called_once_with("What is the capital of France?") + agent2.run.assert_called_once_with("What is the capital of France?") + agent3.run.assert_called_once_with("What is the capital of France?") # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) diff --git a/tests/structs/test_message_pool.py b/tests/structs/test_message_pool.py index cfbb4df5..4dc49a2e 100644 --- a/tests/structs/test_message_pool.py +++ b/tests/structs/test_message_pool.py @@ -8,9 +8,7 @@ def test_message_pool_initialization(): agent2 = Agent(llm=OpenAIChat(), agent_name="agent1") moderator = Agent(llm=OpenAIChat(), agent_name="agent1") agents = [agent1, agent2] - message_pool = MessagePool( - agents=agents, moderator=moderator, turns=5 - ) + message_pool = MessagePool(agents=agents, moderator=moderator, turns=5) assert message_pool.agent == agents assert message_pool.moderator == moderator @@ -20,27 +18,21 @@ def test_message_pool_initialization(): def test_message_pool_add(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool( - agents=[agent1], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) message_pool.add(agent=agent1, content="Hello, world!", turn=1) - assert message_pool.messages == [ - { - "agent": agent1, - "content": "Hello, world!", - "turn": 1, - "visible_to": "all", - "logged": True, - } - ] + assert message_pool.messages == [{ + "agent": agent1, + "content": "Hello, world!", + "turn": 1, + "visible_to": "all", + "logged": True, + }] def test_message_pool_reset(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool( - agents=[agent1], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) message_pool.add(agent=agent1, content="Hello, world!", turn=1) message_pool.reset() @@ -49,9 +41,7 @@ def test_message_pool_reset(): def test_message_pool_last_turn(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool( - agents=[agent1], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) message_pool.add(agent=agent1, content="Hello, world!", turn=1) assert message_pool.last_turn() == 1 @@ -59,9 +49,7 @@ def test_message_pool_last_turn(): def test_message_pool_last_message(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool( - agents=[agent1], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) message_pool.add(agent=agent1, content="Hello, world!", turn=1) assert message_pool.last_message == { @@ -75,28 +63,24 @@ def test_message_pool_last_message(): def test_message_pool_get_all_messages(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool( - agents=[agent1], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) message_pool.add(agent=agent1, content="Hello, world!", turn=1) - assert message_pool.get_all_messages() == [ - { - "agent": agent1, - "content": "Hello, world!", - "turn": 1, - "visible_to": "all", - "logged": True, - } - ] + assert message_pool.get_all_messages() == [{ + "agent": agent1, + "content": "Hello, world!", + "turn": 1, + "visible_to": "all", + "logged": True, + }] def test_message_pool_get_visible_messages(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent2 = 
Agent(agent_name="agent2") - message_pool = MessagePool( - agents=[agent1, agent2], moderator=agent1, turns=5 - ) + message_pool = MessagePool(agents=[agent1, agent2], + moderator=agent1, + turns=5) message_pool.add( agent=agent1, content="Hello, agent2!", @@ -104,14 +88,10 @@ def test_message_pool_get_visible_messages(): visible_to=[agent2.agent_name], ) - assert message_pool.get_visible_messages( - agent=agent2, turn=2 - ) == [ - { - "agent": agent1, - "content": "Hello, agent2!", - "turn": 1, - "visible_to": [agent2.agent_name], - "logged": True, - } - ] + assert message_pool.get_visible_messages(agent=agent2, turn=2) == [{ + "agent": agent1, + "content": "Hello, agent2!", + "turn": 1, + "visible_to": [agent2.agent_name], + "logged": True, + }] diff --git a/tests/structs/test_model_parallizer.py b/tests/structs/test_model_parallizer.py index a0840608..54abe031 100644 --- a/tests/structs/test_model_parallizer.py +++ b/tests/structs/test_model_parallizer.py @@ -11,7 +11,9 @@ from swarms.structs.model_parallizer import ModelParallelizer # Initialize the models custom_config = { "quantize": True, - "quantization_config": {"load_in_4bit": True}, + "quantization_config": { + "load_in_4bit": True + }, "verbose": True, } huggingface_llm = HuggingfaceLLM( @@ -24,14 +26,12 @@ zeroscope_ttv = ZeroscopeTTV() def test_init(): - mp = ModelParallelizer( - [ - huggingface_llm, - mixtral, - gpt4_vision_api, - zeroscope_ttv, - ] - ) + mp = ModelParallelizer([ + huggingface_llm, + mixtral, + gpt4_vision_api, + zeroscope_ttv, + ]) assert isinstance(mp, ModelParallelizer) @@ -39,24 +39,20 @@ def test_run(): mp = ModelParallelizer([huggingface_llm]) result = mp.run( "Create a list of known biggest risks of structural collapse" - " with references" - ) + " with references") assert isinstance(result, str) def test_run_all(): - mp = ModelParallelizer( - [ - huggingface_llm, - mixtral, - gpt4_vision_api, - zeroscope_ttv, - ] - ) + mp = ModelParallelizer([ + huggingface_llm, + mixtral, + gpt4_vision_api, + zeroscope_ttv, + ]) result = mp.run_all( "Create a list of known biggest risks of structural collapse" - " with references" - ) + " with references") assert isinstance(result, list) assert len(result) == 5 @@ -75,10 +71,8 @@ def test_remove_llm(): def test_save_responses_to_file(tmp_path): mp = ModelParallelizer([huggingface_llm]) - mp.run( - "Create a list of known biggest risks of structural collapse" - " with references" - ) + mp.run("Create a list of known biggest risks of structural collapse" + " with references") file = tmp_path / "responses.txt" mp.save_responses_to_file(file) assert file.read_text() != "" @@ -86,10 +80,8 @@ def test_save_responses_to_file(tmp_path): def test_get_task_history(): mp = ModelParallelizer([huggingface_llm]) - mp.run( - "Create a list of known biggest risks of structural collapse" - " with references" - ) + mp.run("Create a list of known biggest risks of structural collapse" + " with references") assert mp.get_task_history() == [ "Create a list of known biggest risks of structural collapse" " with references" @@ -98,10 +90,8 @@ def test_get_task_history(): def test_summary(capsys): mp = ModelParallelizer([huggingface_llm]) - mp.run( - "Create a list of known biggest risks of structural collapse" - " with references" - ) + mp.run("Create a list of known biggest risks of structural collapse" + " with references") mp.summary() captured = capsys.readouterr() assert "Tasks History:" in captured.out @@ -123,8 +113,7 @@ def test_concurrent_run(): mp = 
ModelParallelizer([huggingface_llm, mixtral]) result = mp.concurrent_run( "Create a list of known biggest risks of structural collapse" - " with references" - ) + " with references") assert isinstance(result, list) assert len(result) == 2 diff --git a/tests/structs/test_multi_agent_collab.py b/tests/structs/test_multi_agent_collab.py index 555771e7..df710f29 100644 --- a/tests/structs/test_multi_agent_collab.py +++ b/tests/structs/test_multi_agent_collab.py @@ -73,12 +73,8 @@ def test_run(collaboration): def test_format_results(collaboration): - collaboration.results = [ - {"agent": "Agent1", "response": "Response1"} - ] - formatted_results = collaboration.format_results( - collaboration.results - ) + collaboration.results = [{"agent": "Agent1", "response": "Response1"}] + formatted_results = collaboration.format_results(collaboration.results) assert "Agent1 responded: Response1" in formatted_results @@ -112,7 +108,10 @@ def test_repr(collaboration): def test_load(collaboration): state = { "step": 5, - "results": [{"agent": "Agent1", "response": "Response1"}], + "results": [{ + "agent": "Agent1", + "response": "Response1" + }], } with open(collaboration.saved_file_path_name, "w") as file: json.dump(state, file) diff --git a/tests/structs/test_nonlinear_workflow.py b/tests/structs/test_nonlinear_workflow.py index 2544a7e4..e45e86cc 100644 --- a/tests/structs/test_nonlinear_workflow.py +++ b/tests/structs/test_nonlinear_workflow.py @@ -5,6 +5,7 @@ from swarms.structs import NonlinearWorkflow, Task class TestNonlinearWorkflow: + def test_add_task(self): llm = OpenAIChat(openai_api_key="") task = Task(llm, "What's the weather in miami") @@ -33,9 +34,7 @@ class TestNonlinearWorkflow: workflow = NonlinearWorkflow() workflow.add(task1, task2.name) workflow.add(task2, task1.name) - with pytest.raises( - Exception, match="Circular dependency detected" - ): + with pytest.raises(Exception, match="Circular dependency detected"): workflow.run() def test_run_with_stopping_token(self): diff --git a/tests/structs/test_recursive_workflow.py b/tests/structs/test_recursive_workflow.py index 5b24f921..618f955a 100644 --- a/tests/structs/test_recursive_workflow.py +++ b/tests/structs/test_recursive_workflow.py @@ -53,9 +53,7 @@ def test_run_stop_token_not_in_result(): try: workflow.run() except RecursionError: - pytest.fail( - "RecursiveWorkflow.run caused a RecursionError" - ) + pytest.fail("RecursiveWorkflow.run caused a RecursionError") assert agent.execute.call_count == max_iterations diff --git a/tests/structs/test_sequential_workflow.py b/tests/structs/test_sequential_workflow.py index 0d12991a..1b17b305 100644 --- a/tests/structs/test_sequential_workflow.py +++ b/tests/structs/test_sequential_workflow.py @@ -17,6 +17,7 @@ os.environ["OPENAI_API_KEY"] = "mocked_api_key" # Mock OpenAIChat class for testing class MockOpenAIChat: + def __init__(self, *args, **kwargs): pass @@ -26,6 +27,7 @@ class MockOpenAIChat: # Mock Agent class for testing class MockAgent: + def __init__(self, *args, **kwargs): pass @@ -35,6 +37,7 @@ class MockAgent: # Mock SequentialWorkflow class for testing class MockSequentialWorkflow: + def __init__(self, *args, **kwargs): pass @@ -69,10 +72,7 @@ def test_sequential_workflow_initialization(): assert len(workflow.tasks) == 0 assert workflow.max_loops == 1 assert workflow.autosave is False - assert ( - workflow.saved_state_filepath - == "sequential_workflow_state.json" - ) + assert (workflow.saved_state_filepath == "sequential_workflow_state.json") assert 
workflow.restore_state_filepath is None assert workflow.dashboard is False @@ -177,6 +177,7 @@ def test_sequential_workflow_workflow_dashboard(capfd): # Mock Agent class for async testing class MockAsyncAgent: + def __init__(self, *args, **kwargs): pass diff --git a/tests/structs/test_swarmnetwork.py b/tests/structs/test_swarmnetwork.py index 9dc6d903..6b76cd04 100644 --- a/tests/structs/test_swarmnetwork.py +++ b/tests/structs/test_swarmnetwork.py @@ -20,9 +20,7 @@ def test_swarm_network_init(swarm_network): @patch("swarms.structs.swarm_net.SwarmNetwork.logger") def test_run(mock_logger, swarm_network): swarm_network.run() - assert ( - mock_logger.info.call_count == 10 - ) # 2 log messages per agent + assert (mock_logger.info.call_count == 10) # 2 log messages per agent def test_run_with_mocked_agents(mocker, swarm_network): diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py index de0352af..20773ea7 100644 --- a/tests/structs/test_task.py +++ b/tests/structs/test_task.py @@ -7,8 +7,7 @@ from dotenv import load_dotenv from swarms.models.gpt4_vision_api import GPT4VisionAPI from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( - MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, -) + MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,) from swarms.structs.agent import Agent from swarms.structs.task import Task @@ -21,13 +20,11 @@ def llm(): def test_agent_run_task(llm): - task = ( - "Analyze this image of an assembly line and identify any" - " issues such as misaligned parts, defects, or deviations" - " from the standard assembly process. IF there is anything" - " unsafe in the image, explain why it is unsafe and how it" - " could be improved." - ) + task = ("Analyze this image of an assembly line and identify any" + " issues such as misaligned parts, defects, or deviations" + " from the standard assembly process. 
IF there is anything" + " unsafe in the image, explain why it is unsafe and how it" + " could be improved.") img = "assembly_line.jpg" agent = Agent( @@ -49,9 +46,7 @@ def test_agent_run_task(llm): @pytest.fixture def task(): agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)] - return Task( - id="Task_1", task="Task_Name", agents=agents, dependencies=[] - ) + return Task(id="Task_1", task="Task_Name", agents=agents, dependencies=[]) # Basic tests @@ -189,9 +184,7 @@ def test_task_execute_with_condition(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" condition = mocker.Mock(return_value=True) - task = Task( - description="Test task", agent=mock_agent, condition=condition - ) + task = Task(description="Test task", agent=mock_agent, condition=condition) task.execute() assert task.result == "result" assert task.history == ["result"] @@ -201,9 +194,7 @@ def test_task_execute_with_condition_false(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" condition = mocker.Mock(return_value=False) - task = Task( - description="Test task", agent=mock_agent, condition=condition - ) + task = Task(description="Test task", agent=mock_agent, condition=condition) task.execute() assert task.result is None assert task.history == [] @@ -213,9 +204,7 @@ def test_task_execute_with_action(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" action = mocker.Mock() - task = Task( - description="Test task", agent=mock_agent, action=action - ) + task = Task(description="Test task", agent=mock_agent, action=action) task.execute() assert task.result == "result" assert task.history == ["result"] @@ -243,11 +232,9 @@ def test_task_handle_scheduled_task_future(mocker): agent=mock_agent, schedule_time=datetime.now() + timedelta(days=1), ) - with mocker.patch.object( - task.scheduler, "enter" - ) as mock_enter, mocker.patch.object( - task.scheduler, "run" - ) as mock_run: + with mocker.patch.object(task.scheduler, + "enter") as mock_enter, mocker.patch.object( + task.scheduler, "run") as mock_run: task.handle_scheduled_task() mock_enter.assert_called_once() mock_run.assert_called_once() diff --git a/tests/structs/test_taskqueuebase.py b/tests/structs/test_taskqueuebase.py index 512f72ae..e6648881 100644 --- a/tests/structs/test_taskqueuebase.py +++ b/tests/structs/test_taskqueuebase.py @@ -21,7 +21,9 @@ def agent(): @pytest.fixture() def concrete_task_queue(): + class ConcreteTaskQueue(TaskQueueBase): + def add_task(self, task): pass # Here you would add concrete implementation of add_task @@ -51,9 +53,8 @@ def test_add_task_failure(concrete_task_queue, task): # Assuming the task is somehow invalid # Note: Concrete implementation requires logic defining what an invalid task is concrete_task_queue.add_task(task) - assert ( - concrete_task_queue.add_task(task) is False - ) # Adding the same task again + assert (concrete_task_queue.add_task(task) + is False) # Adding the same task again @pytest.mark.parametrize("invalid_task", [None, "", {}, []]) diff --git a/tests/structs/test_team.py b/tests/structs/test_team.py index 44d64e18..92778134 100644 --- a/tests/structs/test_team.py +++ b/tests/structs/test_team.py @@ -7,6 +7,7 @@ from swarms.structs.team import Team class TestTeam(unittest.TestCase): + def setUp(self): self.agent = Agent( llm=OpenAIChat(openai_api_key=""), @@ -30,16 +31,17 @@ class TestTeam(unittest.TestCase): with self.assertRaises(ValueError): self.team.check_config( - {"config": json.dumps({"agents": 
[], "tasks": []})} - ) + {"config": json.dumps({ + "agents": [], + "tasks": [] + })}) def test_run(self): self.assertEqual(self.team.run(), self.task.execute()) def test_sequential_loop(self): - self.assertEqual( - self.team._Team__sequential_loop(), self.task.execute() - ) + self.assertEqual(self.team._Team__sequential_loop(), + self.task.execute()) def test_log(self): self.assertIsNone(self.team._Team__log("Test message")) diff --git a/tests/structs/test_tests_graph_workflow.py b/tests/structs/test_tests_graph_workflow.py index cb5b17a7..b35ff6f3 100644 --- a/tests/structs/test_tests_graph_workflow.py +++ b/tests/structs/test_tests_graph_workflow.py @@ -27,9 +27,7 @@ def test_set_entry_point(graph_workflow): def test_set_entry_point_nonexistent_node(graph_workflow): - with pytest.raises( - ValueError, match="Node does not exist in graph" - ): + with pytest.raises(ValueError, match="Node does not exist in graph"): graph_workflow.set_entry_point("nonexistent") @@ -42,29 +40,23 @@ def test_add_edge(graph_workflow): def test_add_edge_nonexistent_node(graph_workflow): graph_workflow.add("node1", "value1") - with pytest.raises( - ValueError, match="Node does not exist in graph" - ): + with pytest.raises(ValueError, match="Node does not exist in graph"): graph_workflow.add_edge("node1", "nonexistent") def test_add_conditional_edges(graph_workflow): graph_workflow.add("node1", "value1") graph_workflow.add("node2", "value2") - graph_workflow.add_conditional_edges( - "node1", "condition1", {"condition_value1": "node2"} - ) + graph_workflow.add_conditional_edges("node1", "condition1", + {"condition_value1": "node2"}) assert "node2" in graph_workflow.graph["node1"]["edges"] def test_add_conditional_edges_nonexistent_node(graph_workflow): graph_workflow.add("node1", "value1") - with pytest.raises( - ValueError, match="Node does not exist in graph" - ): + with pytest.raises(ValueError, match="Node does not exist in graph"): graph_workflow.add_conditional_edges( - "node1", "condition1", {"condition_value1": "nonexistent"} - ) + "node1", "condition1", {"condition_value1": "nonexistent"}) def test_run(graph_workflow): diff --git a/tests/telemetry/test_posthog_utils.py b/tests/telemetry/test_posthog_utils.py index 0364cb3a..dd468d1d 100644 --- a/tests/telemetry/test_posthog_utils.py +++ b/tests/telemetry/test_posthog_utils.py @@ -35,9 +35,8 @@ def test_log_activity_posthog(mock_posthog, mock_env): test_function() # Check if the Posthog capture method was called with the expected arguments - mock_posthog.capture.assert_called_once_with( - "test_user_id", event_name, event_properties - ) + mock_posthog.capture.assert_called_once_with("test_user_id", event_name, + event_properties) # Test a scenario where environment variables are not set diff --git a/tests/telemetry/test_user_utils.py b/tests/telemetry/test_user_utils.py index c7b5962c..5fc300e2 100644 --- a/tests/telemetry/test_user_utils.py +++ b/tests/telemetry/test_user_utils.py @@ -46,9 +46,7 @@ def test_generate_unique_identifier(): # Generate unique identifiers and ensure they are valid UUID strings unique_id = generate_unique_identifier() assert isinstance(unique_id, str) - assert uuid.UUID( - unique_id, version=5, namespace=uuid.NAMESPACE_DNS - ) + assert uuid.UUID(unique_id, version=5, namespace=uuid.NAMESPACE_DNS) def test_generate_user_id_edge_case(): @@ -73,9 +71,7 @@ def test_get_system_info_edge_case(): # Test get_system_info for consistency system_info1 = get_system_info() system_info2 = get_system_info() - assert ( - system_info1 == 
system_info2 - ) # Ensure system info remains the same + assert (system_info1 == system_info2) # Ensure system info remains the same def test_generate_unique_identifier_edge_case(): diff --git a/tests/test_upload_tests_to_issues.py b/tests/test_upload_tests_to_issues.py index 0857c58a..c916383e 100644 --- a/tests/test_upload_tests_to_issues.py +++ b/tests/test_upload_tests_to_issues.py @@ -20,9 +20,7 @@ headers = { def run_pytest(): - result = subprocess.run( - ["pytest"], capture_output=True, text=True - ) + result = subprocess.run(["pytest"], capture_output=True, text=True) return result.stdout + result.stderr @@ -56,9 +54,7 @@ def main(): errors = parse_pytest_output(pytest_output) for error in errors: - issue_response = create_github_issue( - error["title"], error["body"] - ) + issue_response = create_github_issue(error["title"], error["body"]) print(f"Issue created: {issue_response.get('html_url')}") diff --git a/tests/tokenizers/test_anthropictokenizer.py b/tests/tokenizers/test_anthropictokenizer.py index 14b2fd86..a0664bde 100644 --- a/tests/tokenizers/test_anthropictokenizer.py +++ b/tests/tokenizers/test_anthropictokenizer.py @@ -16,9 +16,8 @@ def test_default_max_tokens(): assert tokenizer.default_max_tokens() == 100000 -@pytest.mark.parametrize( - "model,tokens", [("claude-2.1", 200000), ("claude", 100000)] -) +@pytest.mark.parametrize("model,tokens", [("claude-2.1", 200000), + ("claude", 100000)]) def test_default_max_tokens_models(model, tokens): tokenizer = AnthropicTokenizer(model=model) assert tokenizer.default_max_tokens() == tokens diff --git a/tests/tokenizers/test_basetokenizer.py b/tests/tokenizers/test_basetokenizer.py index 3956d2de..c41a0545 100644 --- a/tests/tokenizers/test_basetokenizer.py +++ b/tests/tokenizers/test_basetokenizer.py @@ -18,9 +18,7 @@ def test_post_init(base_tokenizer): # 3. Tests for count_tokens_left with different inputs. -def test_count_tokens_left_with_positive_diff( - base_tokenizer, monkeypatch -): +def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch): # Mocking count_tokens to return a specific value monkeypatch.setattr( "swarms.tokenizers.BaseTokenizer.count_tokens", @@ -29,9 +27,7 @@ def test_count_tokens_left_with_positive_diff( assert base_tokenizer.count_tokens_left("some text") == 50 -def test_count_tokens_left_with_zero_diff( - base_tokenizer, monkeypatch -): +def test_count_tokens_left_with_zero_diff(base_tokenizer, monkeypatch): monkeypatch.setattr( "swarms.tokenizers.BaseTokenizer.count_tokens", lambda x, y: 100, diff --git a/tests/tokenizers/test_huggingfacetokenizer.py b/tests/tokenizers/test_huggingfacetokenizer.py index 1eedb6e5..dad69c1f 100644 --- a/tests/tokenizers/test_huggingfacetokenizer.py +++ b/tests/tokenizers/test_huggingfacetokenizer.py @@ -51,18 +51,10 @@ def test_prefix_space_tokens(hftokenizer): # testing _maybe_add_prefix_space method def test__maybe_add_prefix_space(hftokenizer): - assert ( - hftokenizer._maybe_add_prefix_space( - [101, 2003, 2010, 2050, 2001, 2339], " is why" - ) - == " is why" - ) - assert ( - hftokenizer._maybe_add_prefix_space( - [2003, 2010, 2050, 2001, 2339], "is why" - ) - == " is why" - ) + assert (hftokenizer._maybe_add_prefix_space( + [101, 2003, 2010, 2050, 2001, 2339], " is why") == " is why") + assert (hftokenizer._maybe_add_prefix_space([2003, 2010, 2050, 2001, 2339], + "is why") == " is why") # continuing tests for other methods... 
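
The BaseTokenizer tests above patch count_tokens and then assert on count_tokens_left, which implies a simple token-budget calculation (the fixture's max_tokens appears to be 100, since a mocked count of 50 leaves 50). The sketch below illustrates that contract; SketchTokenizer is a hypothetical stand-in for illustration only, not the actual swarms.tokenizers.BaseTokenizer.

from dataclasses import dataclass


@dataclass
class SketchTokenizer:
    max_tokens: int = 100  # assumed default, inferred from the mocked test values

    def count_tokens(self, text: str) -> int:
        # Crude stand-in: whitespace tokens; a real tokenizer uses a model vocabulary.
        return len(text.split())

    def count_tokens_left(self, text: str) -> int:
        # Remaining budget, never negative.
        return max(self.max_tokens - self.count_tokens(text), 0)


if __name__ == "__main__":
    t = SketchTokenizer(max_tokens=100)
    assert t.count_tokens_left(" ".join(["tok"] * 50)) == 50   # 100 - 50
    assert t.count_tokens_left(" ".join(["tok"] * 100)) == 0   # 100 - 100
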
diff --git a/tests/tokenizers/test_openaitokenizer.py b/tests/tokenizers/test_openaitokenizer.py index 3c24748d..8d987164 100644 --- a/tests/tokenizers/test_openaitokenizer.py +++ b/tests/tokenizers/test_openaitokenizer.py @@ -18,31 +18,21 @@ def test_default_max_tokens(openai_tokenizer): assert openai_tokenizer.default_max_tokens() == 4096 -@pytest.mark.parametrize( - "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)] -) +@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3), + (["Hello"], 4)]) def test_count_tokens_single(openai_tokenizer, text, expected_output): - assert ( - openai_tokenizer.count_tokens(text, "gpt-3") - == expected_output - ) + assert (openai_tokenizer.count_tokens(text, "gpt-3") == expected_output) @pytest.mark.parametrize( "texts, expected_output", [(["Hello, world!", "This is a test"], 6), (["Hello"], 4)], ) -def test_count_tokens_multiple( - openai_tokenizer, texts, expected_output -): - assert ( - openai_tokenizer.count_tokens(texts, "gpt-3") - == expected_output - ) +def test_count_tokens_multiple(openai_tokenizer, texts, expected_output): + assert (openai_tokenizer.count_tokens(texts, "gpt-3") == expected_output) -@pytest.mark.parametrize( - "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)] -) +@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3), + (["Hello"], 4)]) def test_len(openai_tokenizer, text, expected_output): assert openai_tokenizer.len(text, "gpt-3") == expected_output diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py index b868f0a1..ecd85097 100644 --- a/tests/tokenizers/test_tokenizer.py +++ b/tests/tokenizers/test_tokenizer.py @@ -7,9 +7,7 @@ from swarms.tokenizers.r_tokenizers import Tokenizer def test_initializer_existing_model_file(): with patch("os.path.exists", return_value=True): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") mock_model.assert_called_with("tokenizers/my_model.model") assert tokenizer.model == mock_model.return_value @@ -17,66 +15,43 @@ def test_initializer_existing_model_file(): def test_initializer_model_folder(): with patch("os.path.exists", side_effect=[False, True]): - with patch( - "swarms.tokenizers.HuggingFaceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.HuggingFaceTokenizer") as mock_model: tokenizer = Tokenizer("my_model_directory") mock_model.assert_called_with("my_model_directory") assert tokenizer.model == mock_model.return_value def test_vocab_size(): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert ( - tokenizer.vocab_size == mock_model.return_value.vocab_size - ) + assert (tokenizer.vocab_size == mock_model.return_value.vocab_size) def test_bos_token_id(): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert ( - tokenizer.bos_token_id - == mock_model.return_value.bos_token_id - ) + assert (tokenizer.bos_token_id == mock_model.return_value.bos_token_id) def test_encode(): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = 
Tokenizer("tokenizers/my_model.model") - assert ( - tokenizer.encode("hello") - == mock_model.return_value.encode.return_value - ) + assert (tokenizer.encode("hello") == + mock_model.return_value.encode.return_value) def test_decode(): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert ( - tokenizer.decode([1, 2, 3]) - == mock_model.return_value.decode.return_value - ) + assert (tokenizer.decode( + [1, 2, 3]) == mock_model.return_value.decode.return_value) def test_call(): - with patch( - "swarms.tokenizers.SentencePieceTokenizer" - ) as mock_model: + with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") assert ( - tokenizer("hello") - == mock_model.return_value.__call__.return_value - ) + tokenizer("hello") == mock_model.return_value.__call__.return_value) # More tests can be added here diff --git a/tests/tools/test_tools_base.py b/tests/tools/test_tools_base.py index 9060f53f..dea29d9c 100644 --- a/tests/tools/test_tools_base.py +++ b/tests/tools/test_tools_base.py @@ -65,23 +65,20 @@ def test_structured_tool_invoke(): def test_tool_creation(): - tool = Tool( - name="test_tool", func=lambda x: x, description="Test tool" - ) + tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") assert tool.name == "test_tool" assert tool.func is not None assert tool.description == "Test tool" def test_tool_ainvoke(): - tool = Tool( - name="test_tool", func=lambda x: x, description="Test tool" - ) + tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") result = tool.ainvoke("input_data") assert result == "input_data" def test_tool_ainvoke_with_coroutine(): + async def async_function(input_data): return input_data @@ -95,6 +92,7 @@ def test_tool_ainvoke_with_coroutine(): def test_tool_args(): + def sample_function(input_data): return input_data @@ -110,6 +108,7 @@ def test_tool_args(): def test_structured_tool_creation(): + class SampleArgsSchema: pass @@ -126,6 +125,7 @@ def test_structured_tool_creation(): def test_structured_tool_ainvoke(): + class SampleArgsSchema: pass @@ -140,6 +140,7 @@ def test_structured_tool_ainvoke(): def test_structured_tool_ainvoke_with_coroutine(): + class SampleArgsSchema: pass @@ -157,6 +158,7 @@ def test_structured_tool_ainvoke_with_coroutine(): def test_structured_tool_args(): + class SampleArgsSchema: pass @@ -182,14 +184,13 @@ def test_tool_ainvoke_exception(): def test_tool_ainvoke_with_coroutine_exception(): - tool = Tool( - name="test_tool", coroutine=None, description="Test tool" - ) + tool = Tool(name="test_tool", coroutine=None, description="Test tool") with pytest.raises(NotImplementedError): tool.ainvoke("input_data") def test_structured_tool_ainvoke_exception(): + class SampleArgsSchema: pass @@ -204,6 +205,7 @@ def test_structured_tool_ainvoke_exception(): def test_structured_tool_ainvoke_with_coroutine_exception(): + class SampleArgsSchema: pass @@ -225,6 +227,7 @@ def test_tool_description_not_provided(): def test_tool_invoke_with_callbacks(): + def sample_function(input_data, callbacks=None): if callbacks: callbacks.on_start() @@ -240,6 +243,7 @@ def test_tool_invoke_with_callbacks(): def test_tool_invoke_with_new_argument(): + def sample_function(input_data, callbacks=None): return input_data @@ -249,6 +253,7 @@ def test_tool_invoke_with_new_argument(): def test_tool_ainvoke_with_new_argument(): 
+ async def async_function(input_data, callbacks=None): return input_data @@ -258,6 +263,7 @@ def test_tool_ainvoke_with_new_argument(): def test_tool_description_from_docstring(): + def sample_function(input_data): """Sample function docstring""" return input_data @@ -267,6 +273,7 @@ def test_tool_description_from_docstring(): def test_tool_ainvoke_with_exceptions(): + async def async_function(input_data): raise ValueError("Test exception") @@ -279,6 +286,7 @@ def test_tool_ainvoke_with_exceptions(): def test_structured_tool_infer_schema_false(): + def sample_function(input_data): return input_data @@ -292,6 +300,7 @@ def test_structured_tool_infer_schema_false(): def test_structured_tool_ainvoke_with_callbacks(): + class SampleArgsSchema: pass @@ -307,15 +316,14 @@ def test_structured_tool_ainvoke_with_callbacks(): args_schema=SampleArgsSchema, ) callbacks = MagicMock() - result = tool.ainvoke( - {"tool_input": "input_data"}, callbacks=callbacks - ) + result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks) assert result == "input_data" callbacks.on_start.assert_called_once() callbacks.on_finish.assert_called_once() def test_structured_tool_description_not_provided(): + class SampleArgsSchema: pass @@ -330,6 +338,7 @@ def test_structured_tool_description_not_provided(): def test_structured_tool_args_schema(): + class SampleArgsSchema: pass @@ -345,6 +354,7 @@ def test_structured_tool_args_schema(): def test_structured_tool_args_schema_inference(): + def sample_function(input_data): return input_data @@ -358,6 +368,7 @@ def test_structured_tool_args_schema_inference(): def test_structured_tool_ainvoke_with_new_argument(): + class SampleArgsSchema: pass @@ -369,13 +380,12 @@ def test_structured_tool_ainvoke_with_new_argument(): func=sample_function, args_schema=SampleArgsSchema, ) - result = tool.ainvoke( - {"tool_input": "input_data"}, callbacks=None - ) + result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None) assert result == "input_data" def test_structured_tool_ainvoke_with_exceptions(): + class SampleArgsSchema: pass @@ -461,9 +471,7 @@ def test_tool_with_runnable(mock_runnable): def test_tool_with_invalid_argument(): # Test passing an invalid argument type with pytest.raises(ValueError): - tool( - 123 - ) # Using an integer instead of a string/callable/Runnable + tool(123) # Using an integer instead of a string/callable/Runnable def test_tool_with_multiple_arguments(mock_func): @@ -525,9 +533,7 @@ class MockSchema(BaseModel): # Test suite starts here class TestTool: # Basic Functionality Tests - def test_tool_with_valid_callable_creates_base_tool( - self, mock_func - ): + def test_tool_with_valid_callable_creates_base_tool(self, mock_func): result = tool(mock_func) assert isinstance(result, BaseTool) @@ -560,6 +566,7 @@ class TestTool: # Error Handling Tests def test_tool_raises_error_without_docstring(self): + def no_doc_func(arg: str) -> str: return arg @@ -567,14 +574,14 @@ class TestTool: tool(no_doc_func) def test_tool_raises_error_runnable_without_object_schema( - self, mock_runnable - ): + self, mock_runnable): with pytest.raises(ValueError): tool(mock_runnable) # Decorator Behavior Tests @pytest.mark.asyncio async def test_async_tool_function(self): + @tool async def async_func(arg: str) -> str: return arg @@ -597,6 +604,7 @@ class TestTool: pass def test_tool_with_different_return_types(self): + @tool def return_int(arg: str) -> int: return int(arg) @@ -615,6 +623,7 @@ class TestTool: # Test with multiple arguments def 
test_tool_with_multiple_args(self): + @tool def concat_strings(a: str, b: str) -> str: return a + b @@ -624,6 +633,7 @@ class TestTool: # Test handling of optional arguments def test_tool_with_optional_args(self): + @tool def greet(name: str, greeting: str = "Hello") -> str: return f"{greeting} {name}" @@ -633,6 +643,7 @@ class TestTool: # Test with variadic arguments def test_tool_with_variadic_args(self): + @tool def sum_numbers(*numbers: int) -> int: return sum(numbers) @@ -642,6 +653,7 @@ class TestTool: # Test with keyword arguments def test_tool_with_kwargs(self): + @tool def build_query(**kwargs) -> str: return "&".join(f"{k}={v}" for k, v in kwargs.items()) @@ -651,6 +663,7 @@ class TestTool: # Test with mixed types of arguments def test_tool_with_mixed_args(self): + @tool def mixed_args(a: int, b: str, *args, **kwargs) -> str: return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}" @@ -659,6 +672,7 @@ class TestTool: # Test error handling with incorrect types def test_tool_error_with_incorrect_types(self): + @tool def add_numbers(a: int, b: int) -> int: return a + b @@ -668,6 +682,7 @@ class TestTool: # Test with nested tools def test_nested_tools(self): + @tool def inner_tool(arg: str) -> str: return f"Inner {arg}" @@ -679,6 +694,7 @@ class TestTool: assert outer_tool("Test") == "Outer Inner Test" def test_tool_with_global_variable(self): + @tool def access_global(arg: str) -> str: return f"{global_var} {arg}" @@ -701,6 +717,7 @@ class TestTool: # Test with complex data structures def test_tool_with_complex_data_structures(self): + @tool def process_data(data: dict) -> list: return [data[key] for key in sorted(data.keys())] @@ -710,6 +727,7 @@ class TestTool: # Test handling exceptions within the tool function def test_tool_handling_internal_exceptions(self): + @tool def function_that_raises(arg: str): if arg == "error": @@ -722,6 +740,7 @@ class TestTool: # Test with functions returning None def test_tool_with_none_return(self): + @tool def return_none(arg: str): return None @@ -735,7 +754,9 @@ class TestTool: # Test with class methods def test_tool_with_class_method(self): + class MyClass: + @tool def method(self, arg: str) -> str: return f"Method {arg}" @@ -745,12 +766,15 @@ class TestTool: # Test tool function with inheritance def test_tool_with_inheritance(self): + class Parent: + @tool def parent_method(self, arg: str) -> str: return f"Parent {arg}" class Child(Parent): + @tool def child_method(self, arg: str) -> str: return f"Child {arg}" @@ -761,7 +785,9 @@ class TestTool: # Test with decorators stacking def test_tool_with_multiple_decorators(self): + def another_decorator(func): + def wrapper(*args, **kwargs): return f"Decorated {func(*args, **kwargs)}" @@ -787,9 +813,7 @@ class TestTool: def thread_target(): results.append(threaded_function(5)) - threads = [ - threading.Thread(target=thread_target) for _ in range(10) - ] + threads = [threading.Thread(target=thread_target) for _ in range(10)] for t in threads: t.start() for t in threads: @@ -799,6 +823,7 @@ class TestTool: # Test with recursive functions def test_tool_with_recursive_function(self): + @tool def recursive_function(n: int) -> int: if n == 0: diff --git a/tests/utils/test_check_device.py b/tests/utils/test_check_device.py index 503a3774..1b07ce08 100644 --- a/tests/utils/test_check_device.py +++ b/tests/utils/test_check_device.py @@ -19,9 +19,8 @@ def test_check_device_no_cuda(monkeypatch): def test_check_device_cuda_exception(monkeypatch): # Mock torch.cuda.is_available to raise an exception - 
monkeypatch.setattr( - torch.cuda, "is_available", lambda: 1 / 0 - ) # Raises ZeroDivisionError + monkeypatch.setattr(torch.cuda, "is_available", + lambda: 1 / 0) # Raises ZeroDivisionError result = check_device(log_level=logging.DEBUG) assert result.type == "cpu" @@ -33,12 +32,8 @@ def test_check_device_one_cuda(monkeypatch): # Mock torch.cuda.device_count to return 1 monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0 - monkeypatch.setattr( - torch.cuda, "memory_allocated", lambda device: 0 - ) - monkeypatch.setattr( - torch.cuda, "memory_reserved", lambda device: 0 - ) + monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0) + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0) result = check_device(log_level=logging.DEBUG) assert len(result) == 1 @@ -52,12 +47,8 @@ def test_check_device_multiple_cuda(monkeypatch): # Mock torch.cuda.device_count to return 4 monkeypatch.setattr(torch.cuda, "device_count", lambda: 4) # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0 - monkeypatch.setattr( - torch.cuda, "memory_allocated", lambda device: 0 - ) - monkeypatch.setattr( - torch.cuda, "memory_reserved", lambda device: 0 - ) + monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0) + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0) result = check_device(log_level=logging.DEBUG) assert len(result) == 4 diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py index 99d38b2c..6af5ea0d 100644 --- a/tests/utils/test_class_args_wrapper.py +++ b/tests/utils/test_class_args_wrapper.py @@ -19,8 +19,7 @@ def test_print_class_parameters_agent(): # Replace with the expected output for Agent class expected_output = ( "Parameter: name, Type: \nParameter: age, Type:" - " " - ) + " ") assert output == expected_output @@ -33,9 +32,7 @@ def test_print_class_parameters_error(): def get_parameters(class_name: str): classes = {"Agent": Agent} if class_name in classes: - return print_class_parameters( - classes[class_name], api_format=True - ) + return print_class_parameters(classes[class_name], api_format=True) else: return {"error": "Class not found"} diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py index 9be83be4..d17ea08d 100644 --- a/tests/utils/test_device.py +++ b/tests/utils/test_device.py @@ -32,18 +32,14 @@ def test_multiple_gpus_available(mocker): def test_device_properties(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch( - "torch.cuda.get_device_capability", return_value=(7, 5) - ) + mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5)) mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) mocker.patch("torch.cuda.memory_allocated", return_value=200) mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch( - "torch.cuda.get_device_name", return_value="Tesla K80" - ) + mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") devices = check_device() assert len(devices) == 1 assert str(devices[0]) == "cuda" @@ -52,27 +48,21 @@ def test_device_properties(mocker): def test_memory_threshold(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch( - "torch.cuda.get_device_capability", return_value=(7, 5) - ) + 
mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5)) mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) - mocker.patch( - "torch.cuda.memory_allocated", return_value=900 - ) # 90% of total memory + mocker.patch("torch.cuda.memory_allocated", + return_value=900) # 90% of total memory mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch( - "torch.cuda.get_device_name", return_value="Tesla K80" - ) + mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") with pytest.warns( - UserWarning, - match=r"Memory usage for device cuda exceeds threshold", + UserWarning, + match=r"Memory usage for device cuda exceeds threshold", ): devices = check_device( - memory_threshold=0.8 - ) # Set memory threshold to 80% + memory_threshold=0.8) # Set memory threshold to 80% assert len(devices) == 1 assert str(devices[0]) == "cuda" @@ -80,27 +70,21 @@ def test_memory_threshold(mocker): def test_compute_capability_threshold(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch( - "torch.cuda.get_device_capability", return_value=(3, 0) - ) # Compute capability 3.0 + mocker.patch("torch.cuda.get_device_capability", + return_value=(3, 0)) # Compute capability 3.0 mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) mocker.patch("torch.cuda.memory_allocated", return_value=200) mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch( - "torch.cuda.get_device_name", return_value="Tesla K80" - ) + mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") with pytest.warns( - UserWarning, - match=( - r"Compute capability for device cuda is below threshold" - ), + UserWarning, + match=(r"Compute capability for device cuda is below threshold"), ): devices = check_device( - capability_threshold=3.5 - ) # Set compute capability threshold to 3.5 + capability_threshold=3.5) # Set compute capability threshold to 3.5 assert len(devices) == 1 assert str(devices[0]) == "cuda" diff --git a/tests/utils/test_display_markdown_message.py b/tests/utils/test_display_markdown_message.py index 1b7cadaa..9a9b9327 100644 --- a/tests/utils/test_display_markdown_message.py +++ b/tests/utils/test_display_markdown_message.py @@ -14,8 +14,7 @@ def test_basic_message(): with mock.patch.object(Console, "print") as mock_print: display_markdown_message("This is a test") mock_print.assert_called_once_with( - Markdown("This is a test", style="cyan") - ) + Markdown("This is a test", style="cyan")) def test_empty_message(): @@ -31,8 +30,7 @@ def test_colors(color): with mock.patch.object(Console, "print") as mock_print: display_markdown_message("This is a test", color) mock_print.assert_called_once_with( - Markdown("This is a test", style=color) - ) + Markdown("This is a test", style=color)) def test_dash_line(): diff --git a/tests/utils/test_extract_code_from_markdown.py b/tests/utils/test_extract_code_from_markdown.py index eb1a3e5d..8bbee570 100644 --- a/tests/utils/test_extract_code_from_markdown.py +++ b/tests/utils/test_extract_code_from_markdown.py @@ -22,12 +22,8 @@ def markdown_content_without_code(): """ -def test_extract_code_from_markdown_with_code( - markdown_content_with_code, -): - extracted_code = extract_code_from_markdown( - markdown_content_with_code - ) +def test_extract_code_from_markdown_with_code(markdown_content_with_code,): + extracted_code = 
extract_code_from_markdown(markdown_content_with_code) assert "def my_func():" in extracted_code assert 'print("This is my function.")' in extracted_code assert "class MyClass:" in extracted_code @@ -35,11 +31,8 @@ def test_extract_code_from_markdown_with_code( def test_extract_code_from_markdown_without_code( - markdown_content_without_code, -): - extracted_code = extract_code_from_markdown( - markdown_content_without_code - ) + markdown_content_without_code,): + extracted_code = extract_code_from_markdown(markdown_content_without_code) assert extracted_code == "" diff --git a/tests/utils/test_find_image_path.py b/tests/utils/test_find_image_path.py index 29b1c627..15de0f54 100644 --- a/tests/utils/test_find_image_path.py +++ b/tests/utils/test_find_image_path.py @@ -8,12 +8,8 @@ from swarms.utils import find_image_path def test_find_image_path_no_images(): - assert ( - find_image_path( - "This is a test string without any image paths." - ) - is None - ) + assert (find_image_path("This is a test string without any image paths.") + is None) def test_find_image_path_one_image(): @@ -23,9 +19,7 @@ def test_find_image_path_one_image(): def test_find_image_path_multiple_images(): text = "This string has two image paths: img1.png, and img2.jpg." - assert ( - find_image_path(text) == "img2.jpg" - ) # Assuming both images exist + assert (find_image_path(text) == "img2.jpg") # Assuming both images exist def test_find_image_path_wrong_input(): diff --git a/tests/utils/test_limit_tokens_from_string.py b/tests/utils/test_limit_tokens_from_string.py index 4d68dccb..cc869dd6 100644 --- a/tests/utils/test_limit_tokens_from_string.py +++ b/tests/utils/test_limit_tokens_from_string.py @@ -4,14 +4,11 @@ from swarms.utils import limit_tokens_from_string def test_limit_tokens_from_string(): - sentence = ( - "This is a test sentence. It is used for testing the number" - " of tokens." - ) + sentence = ("This is a test sentence. It is used for testing the number" + " of tokens.") limited = limit_tokens_from_string(sentence, limit=5) - assert ( - len(limited.split()) <= 5 - ), "The output string has more than 5 tokens." + assert (len(limited.split()) + <= 5), "The output string has more than 5 tokens." def test_limit_zero_tokens(): @@ -21,26 +18,21 @@ def test_limit_zero_tokens(): def test_negative_token_limit(): - sentence = ( - "This test will raise an exception when limit is negative." - ) + sentence = ("This test will raise an exception when limit is negative.") with pytest.raises(Exception): limit_tokens_from_string(sentence, limit=-1) -@pytest.mark.parametrize( - "sentence, model", [("Some sentence", "unavailable-model")] -) +@pytest.mark.parametrize("sentence, model", + [("Some sentence", "unavailable-model")]) def test_unknown_model(sentence, model): with pytest.raises(Exception): limit_tokens_from_string(sentence, model=model) def test_string_token_limit_exceeded(): - sentence = ( - "This is a long sentence with more than twenty tokens which" - " is used for testing. It checks whether the function" - " correctly limits the tokens to a specified amount." - ) + sentence = ("This is a long sentence with more than twenty tokens which" + " is used for testing. It checks whether the function" + " correctly limits the tokens to a specified amount.") limited = limit_tokens_from_string(sentence, limit=20) assert len(limited.split()) <= 20, "The token limit is exceeded." 
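
The limit_tokens_from_string tests above expect truncation to a token budget, an empty string for limit=0, and errors for a negative limit or an unknown model. The sketch below shows one way to satisfy that behaviour, assuming a tiktoken-backed encoder; the default model name and the ValueError guard are assumptions, and the real swarms.utils implementation may differ.

import tiktoken


def limit_tokens_sketch(text: str, model: str = "gpt-4", limit: int = 500) -> str:
    if limit < 0:
        # Assumed guard so negative limits surface as an error, as the tests expect.
        raise ValueError("limit must be non-negative")
    # encoding_for_model raises for model names tiktoken does not recognise.
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(text)[:limit]
    return encoding.decode(tokens)  # decode([]) == "" when limit is 0


if __name__ == "__main__":
    print(limit_tokens_sketch("This is a test sentence used for token limiting.", limit=5))
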
diff --git a/tests/utils/test_load_model_torch.py b/tests/utils/test_load_model_torch.py index c2018c6a..89bd6960 100644 --- a/tests/utils/test_load_model_torch.py +++ b/tests/utils/test_load_model_torch.py @@ -6,6 +6,7 @@ from swarms.utils import load_model_torch class DummyModel(nn.Module): + def __init__(self): super().__init__() self.fc = nn.Linear(10, 2) @@ -25,9 +26,7 @@ def test_load_model_torch_success(tmp_path): model_loaded = load_model_torch(model_path, model=DummyModel()) # Check if loaded model has the same architecture - assert isinstance( - model_loaded, DummyModel - ), "Loaded model type mismatch." + assert isinstance(model_loaded, DummyModel), "Loaded model type mismatch." # Test case 2: Test if function raises FileNotFoundError for non-existent file @@ -66,13 +65,12 @@ def test_load_model_torch_device_handling(tmp_path): # Define a device other than default and load the model to the specified device device = torch.device("cpu") - model_loaded = load_model_torch( - model_path, model=DummyModel(), device=device - ) + model_loaded = load_model_torch(model_path, + model=DummyModel(), + device=device) - assert ( - model_loaded.fc.weight.device == device - ), "Model not loaded to specified device." + assert (model_loaded.fc.weight.device == device + ), "Model not loaded to specified device." # Test case 6: Testing for correct handling of '*args' and '**kwargs' @@ -82,15 +80,14 @@ def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path): torch.save(model.state_dict(), model_path) def mock_torch_load(*args, **kwargs): - assert ( - "pickle_module" in kwargs - ), "Keyword arguments not passed to 'torch.load'." + assert ("pickle_module" + in kwargs), "Keyword arguments not passed to 'torch.load'." # Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly monkeypatch.setattr(torch, "load", mock_torch_load) - load_model_torch( - model_path, model=DummyModel(), pickle_module="dummy_module" - ) + load_model_torch(model_path, + model=DummyModel(), + pickle_module="dummy_module") # Test case 7: Test for model loading on CPU if no GPU is available @@ -103,9 +100,8 @@ def test_load_model_torch_cpu(tmp_path): return False # Monkeypatch to simulate no GPU available - pytest.MonkeyPatch.setattr( - torch.cuda, "is_available", mock_torch_cuda_is_available - ) + pytest.MonkeyPatch.setattr(torch.cuda, "is_available", + mock_torch_cuda_is_available) model_loaded = load_model_torch(model_path, model=DummyModel()) # Ensure model is loaded on CPU diff --git a/tests/utils/test_load_models_torch.py b/tests/utils/test_load_models_torch.py index 3f09f411..12b7605c 100644 --- a/tests/utils/test_load_models_torch.py +++ b/tests/utils/test_load_models_torch.py @@ -42,15 +42,13 @@ def test_load_model_torch_model_specified(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value={"key": "value"}) load_model_torch("model_path", model=mock_model) - mock_model.load_state_dict.assert_called_once_with( - {"key": "value"}, strict=True - ) + mock_model.load_state_dict.assert_called_once_with({"key": "value"}, + strict=True) def test_load_model_torch_model_specified_strict_false(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value={"key": "value"}) load_model_torch("model_path", model=mock_model, strict=False) - mock_model.load_state_dict.assert_called_once_with( - {"key": "value"}, strict=False - ) + mock_model.load_state_dict.assert_called_once_with({"key": "value"}, + strict=False) diff --git 
a/tests/utils/test_math_eval.py b/tests/utils/test_math_eval.py index ae7ee04c..e3ddad9f 100644 --- a/tests/utils/test_math_eval.py +++ b/tests/utils/test_math_eval.py @@ -18,6 +18,7 @@ def func2_with_exception(x): def test_same_results_no_exception(caplog): + @math_eval(func1_no_exception, func2_no_exception) def test_func(x): return x @@ -28,6 +29,7 @@ def test_same_results_no_exception(caplog): def test_func1_exception(caplog): + @math_eval(func1_with_exception, func2_no_exception) def test_func(x): return x diff --git a/tests/utils/test_metrics_decorator.py b/tests/utils/test_metrics_decorator.py index 8c3a8af9..3ad34684 100644 --- a/tests/utils/test_metrics_decorator.py +++ b/tests/utils/test_metrics_decorator.py @@ -10,6 +10,7 @@ from swarms.utils import metrics_decorator # Basic successful test def test_metrics_decorator_success(): + @metrics_decorator def decorated_func(): time.sleep(0.1) @@ -30,8 +31,8 @@ def test_metrics_decorator_success(): ], ) def test_metrics_decorator_with_various_wait_times_and_return_vals( - wait_time, return_val -): + wait_time, return_val): + @metrics_decorator def decorated_func(): time.sleep(wait_time) @@ -55,19 +56,17 @@ def test_metrics_decorator_with_mocked_time(mocker): return ["tok_1", "tok_2"] metrics = decorated_func() - assert ( - metrics - == """ + assert (metrics == """ Time to First Token: 5 Generation Latency: 20 Throughput: 0.1 - """ - ) + """) mocked_time.assert_any_call() # Test to ensure that exceptions in the decorated function are propagated def test_metrics_decorator_raises_exception(): + @metrics_decorator def decorated_func(): raise ValueError("Oops!") @@ -78,6 +77,7 @@ def test_metrics_decorator_raises_exception(): # Test to ensure proper handling when decorated function returns non-list value def test_metrics_decorator_with_non_list_return_val(): + @metrics_decorator def decorated_func(): return "Hello, world!" diff --git a/tests/utils/test_pdf_to_text.py b/tests/utils/test_pdf_to_text.py index 257364b4..115a58b1 100644 --- a/tests/utils/test_pdf_to_text.py +++ b/tests/utils/test_pdf_to_text.py @@ -29,8 +29,8 @@ def test_passing_non_pdf_file(tmpdir): file = tmpdir.join("temp.txt") file.write("This is a test") with pytest.raises( - Exception, - match=r"An error occurred while reading the PDF file", + Exception, + match=r"An error occurred while reading the PDF file", ): pdf_to_text(str(file)) diff --git a/tests/utils/test_prep_torch_inference.py b/tests/utils/test_prep_torch_inference.py index 6af4a9a7..8fcf7666 100644 --- a/tests/utils/test_prep_torch_inference.py +++ b/tests/utils/test_prep_torch_inference.py @@ -9,16 +9,13 @@ from swarms.utils import prep_torch_inference def test_prep_torch_inference(): model_path = "model_path" - device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_mock = Mock() model_mock.eval = Mock() # Mocking the load_model_torch function to return our mock model. 
- with unittest.mock.patch( - "swarms.utils.load_model_torch", return_value=model_mock - ) as _: + with unittest.mock.patch("swarms.utils.load_model_torch", + return_value=model_mock) as _: model = prep_torch_inference(model_path, device) # Check if model was properly loaded and eval function was called diff --git a/tests/utils/test_prep_torch_model_inference.py b/tests/utils/test_prep_torch_model_inference.py index 07da4e97..fb93cfb9 100644 --- a/tests/utils/test_prep_torch_model_inference.py +++ b/tests/utils/test_prep_torch_model_inference.py @@ -3,8 +3,7 @@ from unittest.mock import MagicMock import torch from swarms.utils.prep_torch_model_inference import ( - prep_torch_inference, -) + prep_torch_inference,) def test_prep_torch_inference_no_model_path(): diff --git a/tests/utils/test_print_class_parameters.py b/tests/utils/test_print_class_parameters.py index 9a133ae4..4fde69e5 100644 --- a/tests/utils/test_print_class_parameters.py +++ b/tests/utils/test_print_class_parameters.py @@ -4,27 +4,30 @@ from swarms.utils import print_class_parameters class TestObject: + def __init__(self, value1, value2: int): pass class TestObject2: + def __init__(self: "TestObject2", value1, value2: int = 5): pass def test_class_with_complex_parameters(): + class ComplexArgs: + def __init__(self, value1: list, value2: dict = {}): pass output = {"value1": "", "value2": ""} - assert ( - print_class_parameters(ComplexArgs, api_format=True) == output - ) + assert (print_class_parameters(ComplexArgs, api_format=True) == output) def test_empty_class(): + class Empty: pass @@ -33,7 +36,9 @@ def test_empty_class(): def test_class_with_no_annotations(): + class NoAnnotations: + def __init__(self, value1, value2): pass @@ -41,14 +46,13 @@ def test_class_with_no_annotations(): "value1": "", "value2": "", } - assert ( - print_class_parameters(NoAnnotations, api_format=True) - == output - ) + assert (print_class_parameters(NoAnnotations, api_format=True) == output) def test_class_with_partial_annotations(): + class PartialAnnotations: + def __init__(self, value1, value2: int): pass @@ -56,10 +60,8 @@ def test_class_with_partial_annotations(): "value1": "", "value2": "", } - assert ( - print_class_parameters(PartialAnnotations, api_format=True) - == output - ) + assert (print_class_parameters(PartialAnnotations, + api_format=True) == output) @pytest.mark.parametrize( diff --git a/tests/utils/test_subprocess_code_interpreter.py b/tests/utils/test_subprocess_code_interpreter.py index 3bb800f5..8522a59a 100644 --- a/tests/utils/test_subprocess_code_interpreter.py +++ b/tests/utils/test_subprocess_code_interpreter.py @@ -5,8 +5,7 @@ import threading import pytest from swarms.utils.code_interpreter import ( # Adjust the import according to your project structure - SubprocessCodeInterpreter, -) + SubprocessCodeInterpreter,) # Fixture for the SubprocessCodeInterpreter instance @@ -30,9 +29,8 @@ def test_start_and_terminate_process(interpreter): interpreter.start_process() assert isinstance(interpreter.process, subprocess.Popen) interpreter.terminate() - assert ( - interpreter.process.poll() is not None - ) # Process should be terminated + assert (interpreter.process.poll() + is not None) # Process should be terminated # Test preprocess_code method @@ -46,25 +44,22 @@ def test_preprocess_code(interpreter): # Test detect_active_line method def test_detect_active_line(interpreter): line = "Some line of code" - assert ( - interpreter.detect_active_line(line) is None - ) # Adjust assertion based on implementation + assert 
(interpreter.detect_active_line(line) + is None) # Adjust assertion based on implementation # Test detect_end_of_execution method def test_detect_end_of_execution(interpreter): line = "End of execution line" - assert ( - interpreter.detect_end_of_execution(line) is None - ) # Adjust assertion based on implementation + assert (interpreter.detect_end_of_execution(line) + is None) # Adjust assertion based on implementation # Test line_postprocessor method def test_line_postprocessor(interpreter): line = "Some output line" - assert ( - interpreter.line_postprocessor(line) == line - ) # Adjust assertion based on implementation + assert (interpreter.line_postprocessor(line) == line + ) # Adjust assertion based on implementation # Test handle_stream_output method diff --git a/tests/utils/test_try_except_wrapper.py b/tests/utils/test_try_except_wrapper.py index 26b509fb..d815988e 100644 --- a/tests/utils/test_try_except_wrapper.py +++ b/tests/utils/test_try_except_wrapper.py @@ -2,44 +2,44 @@ from swarms.utils.try_except_wrapper import try_except_wrapper def test_try_except_wrapper_with_no_exception(): + @try_except_wrapper def add(x, y): return x + y result = add(1, 2) - assert ( - result == 3 - ), "The function should return the sum of the arguments" + assert (result == 3), "The function should return the sum of the arguments" def test_try_except_wrapper_with_exception(): + @try_except_wrapper def divide(x, y): return x / y result = divide(1, 0) assert ( - result is None - ), "The function should return None when an exception is raised" + result + is None), "The function should return None when an exception is raised" def test_try_except_wrapper_with_multiple_arguments(): + @try_except_wrapper def concatenate(*args): return "".join(args) result = concatenate("Hello", " ", "world") - assert ( - result == "Hello world" - ), "The function should concatenate the arguments" + assert (result == "Hello world" + ), "The function should concatenate the arguments" def test_try_except_wrapper_with_keyword_arguments(): + @try_except_wrapper def greet(name="world"): return f"Hello, {name}" result = greet(name="Alice") - assert ( - result == "Hello, Alice" - ), "The function should use the keyword arguments" + assert (result == "Hello, Alice" + ), "The function should use the keyword arguments"
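
The try_except_wrapper tests above pin down the decorator's observable contract: normal results pass through unchanged, and any raised exception is swallowed so the call returns None. A minimal decorator with that contract is sketched below; the logging detail is an assumption rather than the library's documented behaviour, and the real swarms.utils.try_except_wrapper may differ.

import functools
import logging

logger = logging.getLogger(__name__)


def try_except_wrapper_sketch(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as error:
            # Swallow the error and report it, matching the "returns None" tests.
            logger.error("Error in %s: %s", func.__name__, error)
            return None

    return wrapper


@try_except_wrapper_sketch
def divide(x, y):
    return x / y


assert divide(6, 2) == 3          # result passes through
assert divide(1, 0) is None       # exception swallowed, None returned
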