From 5b121a74451e122891b3d591508c553dc034699a Mon Sep 17 00:00:00 2001 From: Joshua David Date: Fri, 1 Mar 2024 17:43:43 -0800 Subject: [PATCH] run pre-commit code_quality check --- .gitignore | 2 +- ...tion_agent.py => multion_agent_example.py} | 0 ....py => perimeter_defense_agent_example.py} | 0 ..._protocol.py => swarm_protocol_example.py} | 0 .../{tool_agent.py => tool_agent_example.py} | 0 .../demos/developer_swarm/main_example.py | 2 +- .../{simple_rag.py => simple_rag_example.py} | 0 .../memory/{qdrant.py => qdrant_example.py} | 0 ...zure_openai.py => azure_openai_example.py} | 0 .../models/{miqu.py => miqu_example.py} | 0 ...ze.py => agent_basic_customize_example.py} | 0 ... => agent_with_longterm_memory_example.py} | 0 ... basic_agent_with_azure_openai_example.py} | 0 ...hackathon.py => kyle_hackathon_example.py} | 0 ...y_voting.py => majority_voting_example.py} | 0 ...nt.py => multi_modal_rag_agent_example.py} | 0 ...das_to_str.py => pandas_to_str_example.py} | 0 swarms/memory/lanchain_chroma.py | 2 +- tests/agents/test_multion.py | 24 +- tests/agents/test_tool_agent.py | 81 ++--- tests/memory/test_dictinternalmemory.py | 5 +- tests/memory/test_dictsharedmemory.py | 26 +- .../test_langchainchromavectormemory.py | 30 +- tests/memory/test_pinecone.py | 6 +- tests/memory/test_pq_db.py | 9 +- tests/memory/test_qdrant.py | 12 +- tests/memory/test_short_term_memory.py | 64 ++-- tests/memory/test_sqlite.py | 16 +- tests/memory/test_weaviate.py | 103 +++--- tests/models/test_anthropic.py | 65 ++-- tests/models/test_biogpt.py | 33 +- tests/models/test_cohere.py | 285 +++++++++++----- tests/models/test_elevenlab.py | 46 ++- tests/models/test_fire_function_caller.py | 4 +- tests/models/test_fuyu.py | 37 +- tests/models/test_gemini.py | 63 ++-- tests/models/test_gigabind.py | 59 ++-- tests/models/test_gpt4_vision_api.py | 118 ++++--- tests/models/test_hf.py | 8 +- tests/models/test_hf_pipeline.py | 16 +- tests/models/test_huggingface.py | 35 +- tests/models/test_idefics.py | 88 +++-- tests/models/test_kosmos.py | 41 ++- tests/models/test_llama_function_caller.py | 16 +- tests/models/test_mixtral.py | 16 +- tests/models/test_mpt7b.py | 16 +- tests/models/test_nougat.py | 44 ++- tests/models/test_qwen.py | 40 ++- tests/models/test_speech_t5.py | 46 ++- tests/models/test_ssd_1b.py | 23 +- tests/models/test_timm.py | 10 +- tests/models/test_timm_model.py | 7 +- tests/models/test_togther.py | 19 +- tests/models/test_ultralytics.py | 4 +- tests/models/test_vilt.py | 12 +- tests/models/test_yi_200k.py | 45 ++- tests/models/test_zeroscope.py | 43 ++- tests/structs/test_agent.py | 321 ++++++++++++------ tests/structs/test_autoscaler.py | 8 +- tests/structs/test_base.py | 59 ++-- tests/structs/test_base_workflow.py | 9 +- tests/structs/test_concurrent_workflow.py | 4 +- tests/structs/test_conversation.py | 19 +- tests/structs/test_groupchat.py | 27 +- tests/structs/test_json.py | 7 +- tests/structs/test_majority_voting.py | 36 +- tests/structs/test_message_pool.py | 80 +++-- tests/structs/test_model_parallizer.py | 59 ++-- tests/structs/test_multi_agent_collab.py | 13 +- tests/structs/test_nonlinear_workflow.py | 5 +- tests/structs/test_recursive_workflow.py | 4 +- tests/structs/test_sequential_workflow.py | 9 +- tests/structs/test_swarmnetwork.py | 4 +- tests/structs/test_task.py | 39 ++- tests/structs/test_taskqueuebase.py | 7 +- tests/structs/test_team.py | 12 +- tests/structs/test_tests_graph_workflow.py | 20 +- tests/telemetry/test_posthog_utils.py | 5 +- tests/telemetry/test_user_utils.py | 8 +- tests/test_upload_tests_to_issues.py | 8 +- tests/tokenizers/test_anthropictokenizer.py | 5 +- tests/tokenizers/test_basetokenizer.py | 8 +- tests/tokenizers/test_huggingfacetokenizer.py | 16 +- tests/tokenizers/test_openaitokenizer.py | 24 +- tests/tokenizers/test_tokenizer.py | 53 ++- tests/tools/test_tools_base.py | 77 ++--- tests/utils/test_check_device.py | 21 +- tests/utils/test_class_args_wrapper.py | 7 +- tests/utils/test_device.py | 46 ++- tests/utils/test_display_markdown_message.py | 6 +- .../utils/test_extract_code_from_markdown.py | 15 +- tests/utils/test_find_image_path.py | 12 +- tests/utils/test_limit_tokens_from_string.py | 28 +- tests/utils/test_load_model_torch.py | 32 +- tests/utils/test_load_models_torch.py | 10 +- tests/utils/test_math_eval.py | 2 - tests/utils/test_metrics_decorator.py | 14 +- tests/utils/test_pdf_to_text.py | 4 +- tests/utils/test_prep_torch_inference.py | 9 +- .../utils/test_prep_torch_model_inference.py | 3 +- tests/utils/test_print_class_parameters.py | 24 +- .../utils/test_subprocess_code_interpreter.py | 23 +- tests/utils/test_try_except_wrapper.py | 22 +- 103 files changed, 1781 insertions(+), 1064 deletions(-) rename playground/agents/{multion_agent.py => multion_agent_example.py} (100%) rename playground/agents/{perimeter_defense_agent.py => perimeter_defense_agent_example.py} (100%) rename playground/agents/{swarm_protocol.py => swarm_protocol_example.py} (100%) rename playground/agents/{tool_agent.py => tool_agent_example.py} (100%) rename playground/demos/simple_rag/{simple_rag.py => simple_rag_example.py} (100%) rename playground/memory/{qdrant.py => qdrant_example.py} (100%) rename playground/models/{azure_openai.py => azure_openai_example.py} (100%) rename playground/models/{miqu.py => miqu_example.py} (100%) rename playground/structs/{agent_basic_customize.py => agent_basic_customize_example.py} (100%) rename playground/structs/{agent_with_longterm_memory.py => agent_with_longterm_memory_example.py} (100%) rename playground/structs/{basic_agent_with_azure_openai.py => basic_agent_with_azure_openai_example.py} (100%) rename playground/structs/{kyle_hackathon.py => kyle_hackathon_example.py} (100%) rename playground/structs/{majority_voting.py => majority_voting_example.py} (100%) rename playground/structs/{multi_modal_rag_agent.py => multi_modal_rag_agent_example.py} (100%) rename playground/utils/{pandas_to_str.py => pandas_to_str_example.py} (100%) diff --git a/.gitignore b/.gitignore index e24110aa..d9d1aa3b 100644 --- a/.gitignore +++ b/.gitignore @@ -110,7 +110,7 @@ docs/_build/ # PyBuilder .pybuilder/ target/ -` + # Jupyter Notebook .ipynb_checkpoints diff --git a/playground/agents/multion_agent.py b/playground/agents/multion_agent_example.py similarity index 100% rename from playground/agents/multion_agent.py rename to playground/agents/multion_agent_example.py diff --git a/playground/agents/perimeter_defense_agent.py b/playground/agents/perimeter_defense_agent_example.py similarity index 100% rename from playground/agents/perimeter_defense_agent.py rename to playground/agents/perimeter_defense_agent_example.py diff --git a/playground/agents/swarm_protocol.py b/playground/agents/swarm_protocol_example.py similarity index 100% rename from playground/agents/swarm_protocol.py rename to playground/agents/swarm_protocol_example.py diff --git a/playground/agents/tool_agent.py b/playground/agents/tool_agent_example.py similarity index 100% rename from playground/agents/tool_agent.py rename to playground/agents/tool_agent_example.py diff --git a/playground/demos/developer_swarm/main_example.py b/playground/demos/developer_swarm/main_example.py index 0a2e2a95..c930f3e9 100644 --- a/playground/demos/developer_swarm/main_example.py +++ b/playground/demos/developer_swarm/main_example.py @@ -33,7 +33,7 @@ CODE """ # Initialize the language model -llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000) +llm = OpenAIChat(openai_api_key=api_key, max_tokens=4096) # Documentation agent diff --git a/playground/demos/simple_rag/simple_rag.py b/playground/demos/simple_rag/simple_rag_example.py similarity index 100% rename from playground/demos/simple_rag/simple_rag.py rename to playground/demos/simple_rag/simple_rag_example.py diff --git a/playground/memory/qdrant.py b/playground/memory/qdrant_example.py similarity index 100% rename from playground/memory/qdrant.py rename to playground/memory/qdrant_example.py diff --git a/playground/models/azure_openai.py b/playground/models/azure_openai_example.py similarity index 100% rename from playground/models/azure_openai.py rename to playground/models/azure_openai_example.py diff --git a/playground/models/miqu.py b/playground/models/miqu_example.py similarity index 100% rename from playground/models/miqu.py rename to playground/models/miqu_example.py diff --git a/playground/structs/agent_basic_customize.py b/playground/structs/agent_basic_customize_example.py similarity index 100% rename from playground/structs/agent_basic_customize.py rename to playground/structs/agent_basic_customize_example.py diff --git a/playground/structs/agent_with_longterm_memory.py b/playground/structs/agent_with_longterm_memory_example.py similarity index 100% rename from playground/structs/agent_with_longterm_memory.py rename to playground/structs/agent_with_longterm_memory_example.py diff --git a/playground/structs/basic_agent_with_azure_openai.py b/playground/structs/basic_agent_with_azure_openai_example.py similarity index 100% rename from playground/structs/basic_agent_with_azure_openai.py rename to playground/structs/basic_agent_with_azure_openai_example.py diff --git a/playground/structs/kyle_hackathon.py b/playground/structs/kyle_hackathon_example.py similarity index 100% rename from playground/structs/kyle_hackathon.py rename to playground/structs/kyle_hackathon_example.py diff --git a/playground/structs/majority_voting.py b/playground/structs/majority_voting_example.py similarity index 100% rename from playground/structs/majority_voting.py rename to playground/structs/majority_voting_example.py diff --git a/playground/structs/multi_modal_rag_agent.py b/playground/structs/multi_modal_rag_agent_example.py similarity index 100% rename from playground/structs/multi_modal_rag_agent.py rename to playground/structs/multi_modal_rag_agent_example.py diff --git a/playground/utils/pandas_to_str.py b/playground/utils/pandas_to_str_example.py similarity index 100% rename from playground/utils/pandas_to_str.py rename to playground/utils/pandas_to_str_example.py diff --git a/swarms/memory/lanchain_chroma.py b/swarms/memory/lanchain_chroma.py index cf846d31..a053ba52 100644 --- a/swarms/memory/lanchain_chroma.py +++ b/swarms/memory/lanchain_chroma.py @@ -3,7 +3,7 @@ from pathlib import Path from langchain.chains import RetrievalQA from langchain.chains.question_answering import load_qa_chain -from langchain.embeddings.openai import OpenAIEmbeddings +from langchain_community.embeddings import OpenAIEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain_community.vectorstores import Chroma diff --git a/tests/agents/test_multion.py b/tests/agents/test_multion.py index 66399676..23614934 100644 --- a/tests/agents/test_multion.py +++ b/tests/agents/test_multion.py @@ -23,15 +23,19 @@ def test_multion_agent_run(mock_multion): assert result == "result" assert status == "status" assert last_url == "lastUrl" - mock_multion.browse.assert_called_once_with({ - "cmd": "task", - "url": "https://www.example.com", - "maxSteps": 5, - }) + mock_multion.browse.assert_called_once_with( + { + "cmd": "task", + "url": "https://www.example.com", + "maxSteps": 5, + } + ) # Additional tests for different tasks -@pytest.mark.parametrize("task", ["task1", "task2", "task3", "task4", "task5"]) +@pytest.mark.parametrize( + "task", ["task1", "task2", "task3", "task4", "task5"] +) @patch("swarms.agents.multion_agent.multion") def test_multion_agent_run_different_tasks(mock_multion, task): mock_response = MagicMock() @@ -50,8 +54,6 @@ def test_multion_agent_run_different_tasks(mock_multion, task): assert result == "result" assert status == "status" assert last_url == "lastUrl" - mock_multion.browse.assert_called_once_with({ - "cmd": task, - "url": "https://www.example.com", - "maxSteps": 5 - }) + mock_multion.browse.assert_called_once_with( + {"cmd": task, "url": "https://www.example.com", "maxSteps": 5} + ) diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 4ee71a22..691489c0 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -11,27 +11,18 @@ def test_tool_agent_init(): json_schema = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "number" - }, - "is_student": { - "type": "boolean" - }, - "courses": { - "type": "array", - "items": { - "type": "string" - } - }, + "name": {"type": "string"}, + "age": {"type": "number"}, + "is_student": {"type": "boolean"}, + "courses": {"type": "array", "items": {"type": "string"}}, }, } name = "Test Agent" description = "This is a test agent" - agent = ToolAgent(name, description, model, tokenizer, json_schema) + agent = ToolAgent( + name, description, model, tokenizer, json_schema + ) assert agent.name == name assert agent.description == description @@ -47,29 +38,22 @@ def test_tool_agent_run(mock_run): json_schema = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "number" - }, - "is_student": { - "type": "boolean" - }, - "courses": { - "type": "array", - "items": { - "type": "string" - } - }, + "name": {"type": "string"}, + "age": {"type": "number"}, + "is_student": {"type": "boolean"}, + "courses": {"type": "array", "items": {"type": "string"}}, }, } name = "Test Agent" description = "This is a test agent" - task = ("Generate a person's information based on the following" - " schema:") + task = ( + "Generate a person's information based on the following" + " schema:" + ) - agent = ToolAgent(name, description, model, tokenizer, json_schema) + agent = ToolAgent( + name, description, model, tokenizer, json_schema + ) agent.run(task) mock_run.assert_called_once_with(task) @@ -81,21 +65,10 @@ def test_tool_agent_init_with_kwargs(): json_schema = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "number" - }, - "is_student": { - "type": "boolean" - }, - "courses": { - "type": "array", - "items": { - "type": "string" - } - }, + "name": {"type": "string"}, + "age": {"type": "number"}, + "is_student": {"type": "boolean"}, + "courses": {"type": "array", "items": {"type": "string"}}, }, } name = "Test Agent" @@ -109,8 +82,9 @@ def test_tool_agent_init_with_kwargs(): "max_string_token_length": 20, } - agent = ToolAgent(name, description, model, tokenizer, json_schema, - **kwargs) + agent = ToolAgent( + name, description, model, tokenizer, json_schema, **kwargs + ) assert agent.name == name assert agent.description == description @@ -121,4 +95,7 @@ def test_tool_agent_init_with_kwargs(): assert agent.max_array_length == kwargs["max_array_length"] assert agent.max_number_tokens == kwargs["max_number_tokens"] assert agent.temperature == kwargs["temperature"] - assert (agent.max_string_token_length == kwargs["max_string_token_length"]) + assert ( + agent.max_string_token_length + == kwargs["max_string_token_length"] + ) diff --git a/tests/memory/test_dictinternalmemory.py b/tests/memory/test_dictinternalmemory.py index 409acbf2..7658eb7c 100644 --- a/tests/memory/test_dictinternalmemory.py +++ b/tests/memory/test_dictinternalmemory.py @@ -33,8 +33,9 @@ def test_memory_limit_enforced(memory): # Parameterized Tests -@pytest.mark.parametrize("scores, best_score", [([10, 5, 3], 10), - ([1, 2, 3], 3)]) +@pytest.mark.parametrize( + "scores, best_score", [([10, 5, 3], 10), ([1, 2, 3], 3)] +) def test_get_top_n(scores, best_score, memory): for score in scores: memory.add(score, {"data": f"test{score}"}) diff --git a/tests/memory/test_dictsharedmemory.py b/tests/memory/test_dictsharedmemory.py index f391cdc2..a41ccd8f 100644 --- a/tests/memory/test_dictsharedmemory.py +++ b/tests/memory/test_dictsharedmemory.py @@ -26,7 +26,8 @@ def memory_instance(memory_file): def test_init(memory_file): memory = DictSharedMemory(file_loc=memory_file) assert os.path.exists( - memory.file_loc), "Memory file should be created if non-existent" + memory.file_loc + ), "Memory file should be created if non-existent" def test_add_entry(memory_instance): @@ -43,8 +44,9 @@ def test_get_top_n(memory_instance): memory_instance.add(9.5, "agent123", 1, "Entry A") memory_instance.add(8.5, "agent124", 1, "Entry B") top_1 = memory_instance.get_top_n(1) - assert (len(top_1) == 1 - ), "get_top_n should return the correct number of top entries" + assert ( + len(top_1) == 1 + ), "get_top_n should return the correct number of top entries" # Parameterized tests @@ -57,14 +59,18 @@ def test_get_top_n(memory_instance): # add more test cases ], ) -def test_parametrized_get_top_n(memory_instance, scores, agent_ids, - expected_top_score): +def test_parametrized_get_top_n( + memory_instance, scores, agent_ids, expected_top_score +): for score, agent_id in zip(scores, agent_ids): - memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}") + memory_instance.add( + score, agent_id, 1, f"Entry by {agent_id}" + ) top_1 = memory_instance.get_top_n(1) top_score = next(iter(top_1.values()))["score"] - assert (top_score == expected_top_score - ), "get_top_n should return the entry with top score" + assert ( + top_score == expected_top_score + ), "get_top_n should return the entry with top score" # Exception testing @@ -72,7 +78,9 @@ def test_parametrized_get_top_n(memory_instance, scores, agent_ids, def test_add_entry_invalid_input(memory_instance): with pytest.raises(ValueError): - memory_instance.add("invalid_score", "agent123", 1, "Test Entry") + memory_instance.add( + "invalid_score", "agent123", 1, "Test Entry" + ) # Mocks and monkey-patching diff --git a/tests/memory/test_langchainchromavectormemory.py b/tests/memory/test_langchainchromavectormemory.py index 25316ae2..ee882c6c 100644 --- a/tests/memory/test_langchainchromavectormemory.py +++ b/tests/memory/test_langchainchromavectormemory.py @@ -35,13 +35,16 @@ def qa_mock(): # Example test cases def test_initialization_default_settings(vector_memory): assert vector_memory.chunk_size == 1000 - assert (vector_memory.chunk_overlap == 100 - ) # assuming default overlap of 0.1 + assert ( + vector_memory.chunk_overlap == 100 + ) # assuming default overlap of 0.1 assert vector_memory.loc.exists() def test_add_entry(vector_memory, embeddings_mock): - with patch.object(vector_memory.db, "add_texts") as add_texts_mock: + with patch.object( + vector_memory.db, "add_texts" + ) as add_texts_mock: vector_memory.add("Example text") add_texts_mock.assert_called() @@ -74,17 +77,20 @@ def test_ask_question_returns_string(vector_memory, qa_mock): ), # Mocked object as a placeholder ], ) -def test_search_memory_different_params(vector_memory, query, k, type, - expected): +def test_search_memory_different_params( + vector_memory, query, k, type, expected +): with patch.object( - vector_memory.db, - "max_marginal_relevance_search", - return_value=expected, + vector_memory.db, + "max_marginal_relevance_search", + return_value=expected, ): with patch.object( - vector_memory.db, - "similarity_search_with_score", - return_value=expected, + vector_memory.db, + "similarity_search_with_score", + return_value=expected, ): - result = vector_memory.search_memory(query, k=k, type=type) + result = vector_memory.search_memory( + query, k=k, type=type + ) assert len(result) == (k if k > 0 else 0) diff --git a/tests/memory/test_pinecone.py b/tests/memory/test_pinecone.py index bedce998..a7d4fcea 100644 --- a/tests/memory/test_pinecone.py +++ b/tests/memory/test_pinecone.py @@ -8,7 +8,8 @@ api_key = os.getenv("PINECONE_API_KEY") or "" def test_init(): with patch("pinecone.init") as MockInit, patch( - "pinecone.Index") as MockIndex: + "pinecone.Index" + ) as MockIndex: store = PineconeDB( api_key=api_key, index_name="test_index", @@ -70,7 +71,8 @@ def test_query(): def test_create_index(): with patch("pinecone.init"), patch("pinecone.Index"), patch( - "pinecone.create_index") as MockCreateIndex: + "pinecone.create_index" + ) as MockCreateIndex: store = PineconeDB( api_key=api_key, index_name="test_index", diff --git a/tests/memory/test_pq_db.py b/tests/memory/test_pq_db.py index e3e0925c..5e44f0ba 100644 --- a/tests/memory/test_pq_db.py +++ b/tests/memory/test_pq_db.py @@ -32,7 +32,8 @@ def test_create_vector_model(): def test_add_or_update_vector(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session") as MockSession: + "sqlalchemy.orm.Session" + ) as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", @@ -50,7 +51,8 @@ def test_add_or_update_vector(): def test_query_vectors(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session") as MockSession: + "sqlalchemy.orm.Session" + ) as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", @@ -65,7 +67,8 @@ def test_query_vectors(): def test_delete_vector(): with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session") as MockSession: + "sqlalchemy.orm.Session" + ) as MockSession: db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", diff --git a/tests/memory/test_qdrant.py b/tests/memory/test_qdrant.py index f6023dd6..5f82814c 100644 --- a/tests/memory/test_qdrant.py +++ b/tests/memory/test_qdrant.py @@ -13,8 +13,9 @@ def mock_qdrant_client(): @pytest.fixture def mock_sentence_transformer(): - with patch("sentence_transformers.SentenceTransformer" - ) as MockSentenceTransformer: + with patch( + "sentence_transformers.SentenceTransformer" + ) as MockSentenceTransformer: yield MockSentenceTransformer() @@ -28,7 +29,9 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client): assert qdrant_client.client is not None -def test_load_embedding_model(qdrant_client, mock_sentence_transformer): +def test_load_embedding_model( + qdrant_client, mock_sentence_transformer +): qdrant_client._load_embedding_model("model_name") mock_sentence_transformer.assert_called_once_with("model_name") @@ -36,7 +39,8 @@ def test_load_embedding_model(qdrant_client, mock_sentence_transformer): def test_setup_collection(qdrant_client, mock_qdrant_client): qdrant_client._setup_collection() mock_qdrant_client.get_collection.assert_called_once_with( - qdrant_client.collection_name) + qdrant_client.collection_name + ) def test_add_vectors(qdrant_client, mock_qdrant_client): diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 2dbd9fc9..132da5f6 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -12,29 +12,26 @@ def test_init(): def test_add(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert memory.short_term_memory == [{ - "role": "user", - "message": "Hello, world!" - }] + assert memory.short_term_memory == [ + {"role": "user", "message": "Hello, world!"} + ] def test_get_short_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert memory.get_short_term() == [{ - "role": "user", - "message": "Hello, world!" - }] + assert memory.get_short_term() == [ + {"role": "user", "message": "Hello, world!"} + ] def test_get_medium_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert memory.get_medium_term() == [{ - "role": "user", - "message": "Hello, world!" - }] + assert memory.get_medium_term() == [ + {"role": "user", "message": "Hello, world!"} + ] def test_clear_medium_term(): @@ -48,18 +45,19 @@ def test_clear_medium_term(): def test_get_short_term_memory_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert (memory.get_short_term_memory_str() == - "[{'role': 'user', 'message': 'Hello, world!'}]") + assert ( + memory.get_short_term_memory_str() + == "[{'role': 'user', 'message': 'Hello, world!'}]" + ) def test_update_short_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.update_short_term(0, "user", "Goodbye, world!") - assert memory.get_short_term() == [{ - "role": "user", - "message": "Goodbye, world!" - }] + assert memory.get_short_term() == [ + {"role": "user", "message": "Goodbye, world!"} + ] def test_clear(): @@ -73,10 +71,9 @@ def test_search_memory(): memory = ShortTermMemory() memory.add("user", "Hello, world!") assert memory.search_memory("Hello") == { - "short_term": [(0, { - "role": "user", - "message": "Hello, world!" - })], + "short_term": [ + (0, {"role": "user", "message": "Hello, world!"}) + ], "medium_term": [], } @@ -84,18 +81,19 @@ def test_search_memory(): def test_return_shortmemory_as_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") - assert (memory.return_shortmemory_as_str() == - "[{'role': 'user', 'message': 'Hello, world!'}]") + assert ( + memory.return_shortmemory_as_str() + == "[{'role': 'user', 'message': 'Hello, world!'}]" + ) def test_move_to_medium_term(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert memory.get_medium_term() == [{ - "role": "user", - "message": "Hello, world!" - }] + assert memory.get_medium_term() == [ + {"role": "user", "message": "Hello, world!"} + ] assert memory.get_short_term() == [] @@ -103,8 +101,10 @@ def test_return_medium_memory_as_str(): memory = ShortTermMemory() memory.add("user", "Hello, world!") memory.move_to_medium_term(0) - assert (memory.return_medium_memory_as_str() == - "[{'role': 'user', 'message': 'Hello, world!'}]") + assert ( + memory.return_medium_memory_as_str() + == "[{'role': 'user', 'message': 'Hello, world!'}]" + ) def test_thread_safety(): @@ -114,7 +114,9 @@ def test_thread_safety(): for _ in range(1000): memory.add("user", "Hello, world!") - threads = [threading.Thread(target=add_messages) for _ in range(10)] + threads = [ + threading.Thread(target=add_messages) for _ in range(10) + ] for thread in threads: thread.start() for thread in threads: diff --git a/tests/memory/test_sqlite.py b/tests/memory/test_sqlite.py index 50808817..49d61ef7 100644 --- a/tests/memory/test_sqlite.py +++ b/tests/memory/test_sqlite.py @@ -8,7 +8,9 @@ from swarms.memory.sqlite import SQLiteDB @pytest.fixture def db(): conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" + ) conn.commit() return SQLiteDB(":memory:") @@ -28,7 +30,9 @@ def test_delete(db): def test_update(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - db.update("UPDATE test SET name = ? WHERE name = ?", ("new", "test")) + db.update( + "UPDATE test SET name = ? WHERE name = ?", ("new", "test") + ) result = db.query("SELECT * FROM test") assert result == [(1, "new")] @@ -41,7 +45,9 @@ def test_query(db): def test_execute_query(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - result = db.execute_query("SELECT * FROM test WHERE name = ?", ("test",)) + result = db.execute_query( + "SELECT * FROM test WHERE name = ?", ("test",) + ) assert result == [(1, "test")] @@ -95,4 +101,6 @@ def test_query_with_wrong_query(db): def test_execute_query_with_wrong_query(db): with pytest.raises(sqlite3.OperationalError): - db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",)) + db.execute_query( + "SELECT * FROM wrong WHERE name = ?", ("test",) + ) diff --git a/tests/memory/test_weaviate.py b/tests/memory/test_weaviate.py index f05d1b50..d1a69da0 100644 --- a/tests/memory/test_weaviate.py +++ b/tests/memory/test_weaviate.py @@ -16,7 +16,9 @@ def weaviate_client_mock(): grpc_port="mock_grpc_port", grpc_secure=False, auth_client_secret="mock_api_key", - additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"}, + additional_headers={ + "X-OpenAI-Api-Key": "mock_openai_api_key" + }, additional_config=Mock(), ) @@ -34,15 +36,13 @@ def weaviate_client_mock(): # Define tests for the WeaviateDB class def test_create_collection(weaviate_client_mock): # Test creating a collection - weaviate_client_mock.create_collection("test_collection", [{ - "name": "property" - }]) + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) weaviate_client_mock.client.collections.create.assert_called_with( name="test_collection", vectorizer_config=None, - properties=[{ - "name": "property" - }], + properties=[{"name": "property"}], ) @@ -51,9 +51,11 @@ def test_add_object(weaviate_client_mock): properties = {"name": "John"} weaviate_client_mock.add("test_collection", properties) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.data.insert.assert_called_with( - properties) + properties + ) def test_query_objects(weaviate_client_mock): @@ -61,20 +63,26 @@ def test_query_objects(weaviate_client_mock): query = "name:John" weaviate_client_mock.query("test_collection", query) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=10) + query=query, limit=10 + ) def test_update_object(weaviate_client_mock): # Test updating an object object_id = "12345" properties = {"name": "Jane"} - weaviate_client_mock.update("test_collection", object_id, properties) + weaviate_client_mock.update( + "test_collection", object_id, properties + ) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.data.update.assert_called_with( - object_id, properties) + object_id, properties + ) def test_delete_object(weaviate_client_mock): @@ -82,23 +90,25 @@ def test_delete_object(weaviate_client_mock): object_id = "12345" weaviate_client_mock.delete("test_collection", object_id) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with( - object_id) + object_id + ) -def test_create_collection_with_vectorizer_config(weaviate_client_mock,): +def test_create_collection_with_vectorizer_config( + weaviate_client_mock, +): # Test creating a collection with vectorizer configuration vectorizer_config = {"config_key": "config_value"} - weaviate_client_mock.create_collection("test_collection", [{ - "name": "property" - }], vectorizer_config) + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}], vectorizer_config + ) weaviate_client_mock.client.collections.create.assert_called_with( name="test_collection", vectorizer_config=vectorizer_config, - properties=[{ - "name": "property" - }], + properties=[{"name": "property"}], ) @@ -108,9 +118,11 @@ def test_query_objects_with_limit(weaviate_client_mock): limit = 20 weaviate_client_mock.query("test_collection", query, limit) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=limit) + query=query, limit=limit + ) def test_query_objects_without_limit(weaviate_client_mock): @@ -118,29 +130,33 @@ def test_query_objects_without_limit(weaviate_client_mock): query = "name:John" weaviate_client_mock.query("test_collection", query) weaviate_client_mock.client.collections.get.assert_called_with( - "test_collection") + "test_collection" + ) weaviate_client_mock.client.collections.query.bm25.assert_called_with( - query=query, limit=10) + query=query, limit=10 + ) def test_create_collection_failure(weaviate_client_mock): # Test failure when creating a collection with patch( - "weaviate_client.weaviate.collections.create", - side_effect=Exception("Create error"), + "weaviate_client.weaviate.collections.create", + side_effect=Exception("Create error"), ): - with pytest.raises(Exception, match="Error creating collection"): - weaviate_client_mock.create_collection("test_collection", [{ - "name": "property" - }]) + with pytest.raises( + Exception, match="Error creating collection" + ): + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) def test_add_object_failure(weaviate_client_mock): # Test failure when adding an object properties = {"name": "John"} with patch( - "weaviate_client.weaviate.collections.data.insert", - side_effect=Exception("Insert error"), + "weaviate_client.weaviate.collections.data.insert", + side_effect=Exception("Insert error"), ): with pytest.raises(Exception, match="Error adding object"): weaviate_client_mock.add("test_collection", properties) @@ -150,8 +166,8 @@ def test_query_objects_failure(weaviate_client_mock): # Test failure when querying objects query = "name:John" with patch( - "weaviate_client.weaviate.collections.query.bm25", - side_effect=Exception("Query error"), + "weaviate_client.weaviate.collections.query.bm25", + side_effect=Exception("Query error"), ): with pytest.raises(Exception, match="Error querying objects"): weaviate_client_mock.query("test_collection", query) @@ -162,20 +178,21 @@ def test_update_object_failure(weaviate_client_mock): object_id = "12345" properties = {"name": "Jane"} with patch( - "weaviate_client.weaviate.collections.data.update", - side_effect=Exception("Update error"), + "weaviate_client.weaviate.collections.data.update", + side_effect=Exception("Update error"), ): with pytest.raises(Exception, match="Error updating object"): - weaviate_client_mock.update("test_collection", object_id, - properties) + weaviate_client_mock.update( + "test_collection", object_id, properties + ) def test_delete_object_failure(weaviate_client_mock): # Test failure when deleting an object object_id = "12345" with patch( - "weaviate_client.weaviate.collections.data.delete_by_id", - side_effect=Exception("Delete error"), + "weaviate_client.weaviate.collections.data.delete_by_id", + side_effect=Exception("Delete error"), ): with pytest.raises(Exception, match="Error deleting object"): weaviate_client_mock.delete("test_collection", object_id) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3ec19e8c..cc48479a 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -8,11 +8,12 @@ from swarms.models.anthropic import Anthropic # Mock the Anthropic API client for testing class MockAnthropicClient: - def __init__(self, *args, **kwargs): pass - def completions_create(self, prompt, stop_sequences, stream, **kwargs): + def completions_create( + self, prompt, stop_sequences, stream, **kwargs + ): return MockAnthropicResponse() @@ -45,7 +46,9 @@ def test_anthropic_init_default_values(anthropic_instance): assert anthropic_instance.streaming is False assert anthropic_instance.default_request_timeout == 600 assert ( - anthropic_instance.anthropic_api_url == "https://test.anthropic.com") + anthropic_instance.anthropic_api_url + == "https://test.anthropic.com" + ) assert anthropic_instance.anthropic_api_key == "test_api_key" @@ -76,8 +79,9 @@ def test_anthropic_default_params(anthropic_instance): } -def test_anthropic_run(mock_anthropic_env, mock_requests_post, - anthropic_instance): +def test_anthropic_run( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -101,8 +105,9 @@ def test_anthropic_run(mock_anthropic_env, mock_requests_post, ) -def test_anthropic_call(mock_anthropic_env, mock_requests_post, - anthropic_instance): +def test_anthropic_call( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -126,8 +131,9 @@ def test_anthropic_call(mock_anthropic_env, mock_requests_post, ) -def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post, - anthropic_instance): +def test_anthropic_exception_handling( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"error": "An error occurred"} mock_requests_post.return_value = mock_response @@ -142,7 +148,6 @@ def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post, class MockAnthropicResponse: - def __init__(self): self.completion = "Mocked Response from Anthropic" @@ -168,7 +173,9 @@ def test_anthropic_async_call_method(anthropic_instance): def test_anthropic_async_stream_method(anthropic_instance): - async_generator = anthropic_instance.async_stream("Translate to French.") + async_generator = anthropic_instance.async_stream( + "Translate to French." + ) for token in async_generator: assert isinstance(token, str) @@ -192,51 +199,63 @@ def test_anthropic_wrap_prompt(anthropic_instance): def test_anthropic_convert_prompt(anthropic_instance): prompt = "What is the meaning of life?" converted_prompt = anthropic_instance.convert_prompt(prompt) - assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) + assert converted_prompt.startswith( + anthropic_instance.HUMAN_PROMPT + ) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) def test_anthropic_call_with_stop(anthropic_instance): - response = anthropic_instance("Translate to French.", - stop=["stop1", "stop2"]) + response = anthropic_instance( + "Translate to French.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Anthropic" def test_anthropic_stream_with_stop(anthropic_instance): - generator = anthropic_instance.stream("Write a story.", - stop=["stop1", "stop2"]) + generator = anthropic_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_anthropic_async_call_with_stop(anthropic_instance): - response = anthropic_instance.async_call("Tell me a joke.", - stop=["stop1", "stop2"]) + response = anthropic_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Anthropic" def test_anthropic_async_stream_with_stop(anthropic_instance): - async_generator = anthropic_instance.async_stream("Translate to French.", - stop=["stop1", "stop2"]) + async_generator = anthropic_instance.async_stream( + "Translate to French.", stop=["stop1", "stop2"] + ) for token in async_generator: assert isinstance(token, str) -def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance,): +def test_anthropic_get_num_tokens_with_count_tokens( + anthropic_instance, +): anthropic_instance.count_tokens = Mock(return_value=10) text = "This is a test sentence." num_tokens = anthropic_instance.get_num_tokens(text) assert num_tokens == 10 -def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance,): +def test_anthropic_get_num_tokens_without_count_tokens( + anthropic_instance, +): del anthropic_instance.count_tokens with pytest.raises(NameError): text = "This is a test sentence." anthropic_instance.get_num_tokens(text) -def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance,): +def test_anthropic_wrap_prompt_without_human_ai_prompt( + anthropic_instance, +): del anthropic_instance.HUMAN_PROMPT del anthropic_instance.AI_PROMPT prompt = "What is the meaning of life?" diff --git a/tests/models/test_biogpt.py b/tests/models/test_biogpt.py index 8f27e62a..e6093729 100644 --- a/tests/models/test_biogpt.py +++ b/tests/models/test_biogpt.py @@ -48,8 +48,10 @@ def test_cell_biology_response(biogpt_instance): # 40. Test for a question about protein structure def test_protein_structure_response(biogpt_instance): - question = ("What's the difference between alpha helix and beta sheet" - " structures in proteins?") + question = ( + "What's the difference between alpha helix and beta sheet" + " structures in proteins?" + ) response = biogpt_instance(question) assert response assert isinstance(response, str) @@ -81,7 +83,9 @@ def test_bioinformatics_response(biogpt_instance): # 44. Test for a neuroscience question def test_neuroscience_response(biogpt_instance): - question = ("Explain the function of synapses in the nervous system.") + question = ( + "Explain the function of synapses in the nervous system." + ) response = biogpt_instance(question) assert response assert isinstance(response, str) @@ -104,11 +108,8 @@ def test_init(bio_gpt): def test_call(bio_gpt, monkeypatch): - def mock_pipeline(*args, **kwargs): - class MockGenerator: - def __call__(self, text, **kwargs): return ["Generated text"] @@ -166,7 +167,9 @@ def test_get_config_return_type(biogpt_instance): # 28. Test saving model functionality by checking if files are created @patch.object(BioGptForCausalLM, "save_pretrained") @patch.object(BioGptTokenizer, "save_pretrained") -def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): +def test_save_model( + mock_save_model, mock_save_tokenizer, biogpt_instance +): path = "test_path" biogpt_instance.save_model(path) mock_save_model.assert_called_once_with(path) @@ -176,7 +179,9 @@ def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): # 29. Test loading model from path @patch.object(BioGptForCausalLM, "from_pretrained") @patch.object(BioGptTokenizer, "from_pretrained") -def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance): +def test_load_from_path( + mock_load_model, mock_load_tokenizer, biogpt_instance +): path = "test_path" biogpt_instance.load_from_path(path) mock_load_model.assert_called_once_with(path) @@ -193,7 +198,9 @@ def test_print_model_metadata(biogpt_instance): # 31. Test that beam_search_decoding uses the correct number of beams @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): +def test_beam_search_decoding_num_beams( + mock_generate, biogpt_instance +): biogpt_instance.beam_search_decoding("test_sentence", num_beams=7) _, kwargs = mock_generate.call_args assert kwargs["num_beams"] == 7 @@ -201,8 +208,12 @@ def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): # 32. Test if beam_search_decoding handles early_stopping @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance): - biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False) +def test_beam_search_decoding_early_stopping( + mock_generate, biogpt_instance +): + biogpt_instance.beam_search_decoding( + "test_sentence", early_stopping=False + ) _, kwargs = mock_generate.call_args assert kwargs["early_stopping"] is False diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 347cb9fc..8a1147d3 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -42,7 +42,9 @@ def test_cohere_async_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - cohere_instance.async_call("Error handling with invalid API key.") + cohere_instance.async_call( + "Error handling with invalid API key." + ) def test_cohere_stream_api_error_handling(cohere_instance): @@ -51,7 +53,8 @@ def test_cohere_stream_api_error_handling(cohere_instance): cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): generator = cohere_instance.stream( - "Error handling with invalid API key.") + "Error handling with invalid API key." + ) for token in generator: pass @@ -91,26 +94,31 @@ def test_cohere_convert_prompt(cohere_instance): def test_cohere_call_with_stop(cohere_instance): - response = cohere_instance("Translate to French.", stop=["stop1", "stop2"]) + response = cohere_instance( + "Translate to French.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Cohere" def test_cohere_stream_with_stop(cohere_instance): - generator = cohere_instance.stream("Write a story.", - stop=["stop1", "stop2"]) + generator = cohere_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_cohere_async_call_with_stop(cohere_instance): - response = cohere_instance.async_call("Tell me a joke.", - stop=["stop1", "stop2"]) + response = cohere_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Cohere" def test_cohere_async_stream_with_stop(cohere_instance): - async_generator = cohere_instance.async_stream("Translate to French.", - stop=["stop1", "stop2"]) + async_generator = cohere_instance.async_stream( + "Translate to French.", stop=["stop1", "stop2"] + ) for token in async_generator: assert isinstance(token, str) @@ -166,8 +174,12 @@ def test_base_cohere_validate_environment_without_cohere(): # Test cases for benchmarking generations with various models def test_cohere_generate_with_command_light(cohere_instance): cohere_instance.model = "command-light" - response = cohere_instance("Generate text with Command Light model.") - assert response.startswith("Generated text with Command Light model") + response = cohere_instance( + "Generate text with Command Light model." + ) + assert response.startswith( + "Generated text with Command Light model" + ) def test_cohere_generate_with_command(cohere_instance): @@ -190,54 +202,74 @@ def test_cohere_generate_with_base(cohere_instance): def test_cohere_generate_with_embed_english_v2(cohere_instance): cohere_instance.model = "embed-english-v2.0" - response = cohere_instance("Generate embeddings with English v2.0 model.") - assert response.startswith("Generated embeddings with English v2.0 model") + response = cohere_instance( + "Generate embeddings with English v2.0 model." + ) + assert response.startswith( + "Generated embeddings with English v2.0 model" + ) def test_cohere_generate_with_embed_english_light_v2(cohere_instance): cohere_instance.model = "embed-english-light-v2.0" response = cohere_instance( - "Generate embeddings with English Light v2.0 model.") + "Generate embeddings with English Light v2.0 model." + ) assert response.startswith( - "Generated embeddings with English Light v2.0 model") + "Generated embeddings with English Light v2.0 model" + ) def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance( - "Generate embeddings with Multilingual v2.0 model.") + "Generate embeddings with Multilingual v2.0 model." + ) assert response.startswith( - "Generated embeddings with Multilingual v2.0 model") + "Generated embeddings with Multilingual v2.0 model" + ) def test_cohere_generate_with_embed_english_v3(cohere_instance): cohere_instance.model = "embed-english-v3.0" - response = cohere_instance("Generate embeddings with English v3.0 model.") - assert response.startswith("Generated embeddings with English v3.0 model") + response = cohere_instance( + "Generate embeddings with English v3.0 model." + ) + assert response.startswith( + "Generated embeddings with English v3.0 model" + ) def test_cohere_generate_with_embed_english_light_v3(cohere_instance): cohere_instance.model = "embed-english-light-v3.0" response = cohere_instance( - "Generate embeddings with English Light v3.0 model.") + "Generate embeddings with English Light v3.0 model." + ) assert response.startswith( - "Generated embeddings with English Light v3.0 model") + "Generated embeddings with English Light v3.0 model" + ) def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance( - "Generate embeddings with Multilingual v3.0 model.") + "Generate embeddings with Multilingual v3.0 model." + ) assert response.startswith( - "Generated embeddings with Multilingual v3.0 model") + "Generated embeddings with Multilingual v3.0 model" + ) -def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance,): +def test_cohere_generate_with_embed_multilingual_light_v3( + cohere_instance, +): cohere_instance.model = "embed-multilingual-light-v3.0" response = cohere_instance( - "Generate embeddings with Multilingual Light v3.0 model.") + "Generate embeddings with Multilingual Light v3.0 model." + ) assert response.startswith( - "Generated embeddings with Multilingual Light v3.0 model") + "Generated embeddings with Multilingual Light v3.0 model" + ) # Add more test cases to benchmark other models and functionalities @@ -267,13 +299,17 @@ def test_cohere_call_with_embed_english_v3_model(cohere_instance): assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance,): +def test_cohere_call_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance,): +def test_cohere_call_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) @@ -293,7 +329,9 @@ def test_cohere_call_with_long_prompt(cohere_instance): def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit." + ) with pytest.raises(ValueError): cohere_instance(prompt) @@ -326,14 +364,18 @@ def test_cohere_stream_with_embed_english_v3_model(cohere_instance): assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance,): +def test_cohere_stream_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" generator = cohere_instance.stream("Write a story.") for token in generator: assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance,): +def test_cohere_stream_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" generator = cohere_instance.stream("Write a story.") for token in generator: @@ -352,25 +394,33 @@ def test_cohere_async_call_with_base_model(cohere_instance): assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v2_model(cohere_instance,): +def test_cohere_async_call_with_embed_english_v2_model( + cohere_instance, +): cohere_instance.model = "embed-english-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v3_model(cohere_instance,): +def test_cohere_async_call_with_embed_english_v3_model( + cohere_instance, +): cohere_instance.model = "embed-english-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance,): +def test_cohere_async_call_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance,): +def test_cohere_async_call_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) @@ -390,28 +440,36 @@ def test_cohere_async_stream_with_base_model(cohere_instance): assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance,): +def test_cohere_async_stream_with_embed_english_v2_model( + cohere_instance, +): cohere_instance.model = "embed-english-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance,): +def test_cohere_async_stream_with_embed_english_v3_model( + cohere_instance, +): cohere_instance.model = "embed-english-v3.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance,): +def test_cohere_async_stream_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,): +def test_cohere_async_stream_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: @@ -421,7 +479,9 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,): def test_cohere_representation_model_embedding(cohere_instance): # Test using the Representation model for text embedding cohere_instance.model = "embed-english-v3.0" - embedding = cohere_instance.embed("Generate an embedding for this text.") + embedding = cohere_instance.embed( + "Generate an embedding for this text." + ) assert isinstance(embedding, list) assert len(embedding) > 0 @@ -435,20 +495,26 @@ def test_cohere_representation_model_classification(cohere_instance): assert "score" in classification -def test_cohere_representation_model_language_detection(cohere_instance,): +def test_cohere_representation_model_language_detection( + cohere_instance, +): # Test using the Representation model for language detection cohere_instance.model = "embed-english-v3.0" language = cohere_instance.detect_language( - "Detect the language of this text.") + "Detect the language of this text." + ) assert isinstance(language, str) def test_cohere_representation_model_max_tokens_limit_exceeded( - cohere_instance,): + cohere_instance, +): # Test handling max tokens limit exceeded error cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -456,80 +522,102 @@ def test_cohere_representation_model_max_tokens_limit_exceeded( # Add more production-grade test cases based on real-world scenarios -def test_cohere_representation_model_multilingual_embedding(cohere_instance,): +def test_cohere_representation_model_multilingual_embedding( + cohere_instance, +): # Test using the Representation model for multilingual text embedding cohere_instance.model = "embed-multilingual-v3.0" - embedding = cohere_instance.embed("Generate multilingual embeddings.") + embedding = cohere_instance.embed( + "Generate multilingual embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_multilingual_classification( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for multilingual text classification cohere_instance.model = "embed-multilingual-v3.0" - classification = cohere_instance.classify("Classify multilingual text.") + classification = cohere_instance.classify( + "Classify multilingual text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_multilingual_language_detection( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for multilingual language detection cohere_instance.model = "embed-multilingual-v3.0" language = cohere_instance.detect_language( - "Detect the language of multilingual text.") + "Detect the language of multilingual text." + ) assert isinstance(language, str) def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( - cohere_instance,): + cohere_instance, +): # Test handling max tokens limit exceeded error for multilingual model cohere_instance.model = "embed-multilingual-v3.0" cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit" - " for multilingual model.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit" + " for multilingual model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) def test_cohere_representation_model_multilingual_light_embedding( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for multilingual light text embedding cohere_instance.model = "embed-multilingual-light-v3.0" - embedding = cohere_instance.embed("Generate multilingual light embeddings.") + embedding = cohere_instance.embed( + "Generate multilingual light embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_multilingual_light_classification( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for multilingual light text classification cohere_instance.model = "embed-multilingual-light-v3.0" classification = cohere_instance.classify( - "Classify multilingual light text.") + "Classify multilingual light text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_multilingual_light_language_detection( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for multilingual light language detection cohere_instance.model = "embed-multilingual-light-v3.0" language = cohere_instance.detect_language( - "Detect the language of multilingual light text.") + "Detect the language of multilingual light text." + ) assert isinstance(language, str) def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded( - cohere_instance,): + cohere_instance, +): # Test handling max tokens limit exceeded error for multilingual light model cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit" - " for multilingual light model.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit" + " for multilingual light model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -537,14 +625,18 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede def test_cohere_command_light_model(cohere_instance): # Test using the Command Light model for text generation cohere_instance.model = "command-light" - response = cohere_instance("Generate text using Command Light model.") + response = cohere_instance( + "Generate text using Command Light model." + ) assert isinstance(response, str) def test_cohere_base_light_model(cohere_instance): # Test using the Base Light model for text generation cohere_instance.model = "base-light" - response = cohere_instance("Generate text using Base Light model.") + response = cohere_instance( + "Generate text using Base Light model." + ) assert isinstance(response, str) @@ -555,7 +647,9 @@ def test_cohere_generate_summarize_endpoint(cohere_instance): assert isinstance(response, str) -def test_cohere_representation_model_english_embedding(cohere_instance,): +def test_cohere_representation_model_english_embedding( + cohere_instance, +): # Test using the Representation model for English text embedding cohere_instance.model = "embed-english-v3.0" embedding = cohere_instance.embed("Generate English embeddings.") @@ -563,69 +657,90 @@ def test_cohere_representation_model_english_embedding(cohere_instance,): assert len(embedding) > 0 -def test_cohere_representation_model_english_classification(cohere_instance,): +def test_cohere_representation_model_english_classification( + cohere_instance, +): # Test using the Representation model for English text classification cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify("Classify English text.") + classification = cohere_instance.classify( + "Classify English text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_english_language_detection( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for English language detection cohere_instance.model = "embed-english-v3.0" language = cohere_instance.detect_language( - "Detect the language of English text.") + "Detect the language of English text." + ) assert isinstance(language, str) def test_cohere_representation_model_english_max_tokens_limit_exceeded( - cohere_instance,): + cohere_instance, +): # Test handling max tokens limit exceeded error for English model cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit" - " for English model.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit" + " for English model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) -def test_cohere_representation_model_english_light_embedding(cohere_instance,): +def test_cohere_representation_model_english_light_embedding( + cohere_instance, +): # Test using the Representation model for English light text embedding cohere_instance.model = "embed-english-light-v3.0" - embedding = cohere_instance.embed("Generate English light embeddings.") + embedding = cohere_instance.embed( + "Generate English light embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 def test_cohere_representation_model_english_light_classification( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for English light text classification cohere_instance.model = "embed-english-light-v3.0" - classification = cohere_instance.classify("Classify English light text.") + classification = cohere_instance.classify( + "Classify English light text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification def test_cohere_representation_model_english_light_language_detection( - cohere_instance,): + cohere_instance, +): # Test using the Representation model for English light language detection cohere_instance.model = "embed-english-light-v3.0" language = cohere_instance.detect_language( - "Detect the language of English light text.") + "Detect the language of English light text." + ) assert isinstance(language, str) def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( - cohere_instance,): + cohere_instance, +): # Test handling max tokens limit exceeded error for English light model cohere_instance.model = "embed-english-light-v3.0" cohere_instance.max_tokens = 10 - prompt = ("This is a test prompt that will exceed the max tokens limit" - " for English light model.") + prompt = ( + "This is a test prompt that will exceed the max tokens limit" + " for English light model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -633,7 +748,9 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( def test_cohere_command_model(cohere_instance): # Test using the Command model for text generation cohere_instance.model = "command" - response = cohere_instance("Generate text using the Command model.") + response = cohere_instance( + "Generate text using the Command model." + ) assert isinstance(response, str) @@ -647,7 +764,9 @@ def test_cohere_invalid_model(cohere_instance): cohere_instance("Generate text using an invalid model.") -def test_cohere_base_model_generation_with_max_tokens(cohere_instance,): +def test_cohere_base_model_generation_with_max_tokens( + cohere_instance, +): # Test generating text using the base model with a specified max_tokens limit cohere_instance.model = "base" cohere_instance.max_tokens = 20 diff --git a/tests/models/test_elevenlab.py b/tests/models/test_elevenlab.py index 11359f9a..da41ca53 100644 --- a/tests/models/test_elevenlab.py +++ b/tests/models/test_elevenlab.py @@ -30,30 +30,45 @@ def test_run_text_to_speech(eleven_labs_tool): def test_play_speech(eleven_labs_tool): - with patch("builtins.open", mock_open(read_data="fake_audio_data")): + with patch( + "builtins.open", mock_open(read_data="fake_audio_data") + ): eleven_labs_tool.play(EXPECTED_SPEECH_FILE) def test_stream_speech(eleven_labs_tool): - with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file: + with patch( + "tempfile.NamedTemporaryFile", mock_open() + ) as mock_file: eleven_labs_tool.stream_speech(SAMPLE_TEXT) - mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False) + mock_file.assert_called_with( + mode="bx", suffix=".wav", delete=False + ) # Testing fixture and environment variables def test_api_key_validation(eleven_labs_tool): - with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY): + with patch( + "langchain.utils.get_from_dict_or_env", return_value=API_KEY + ): values = {"eleven_api_key": None} - validated_values = eleven_labs_tool.validate_environment(values) + validated_values = eleven_labs_tool.validate_environment( + values + ) assert "eleven_api_key" in validated_values # Mocking the external library def test_run_text_to_speech_with_mock(eleven_labs_tool): - with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch( - "your_module._import_elevenlabs") as mock_elevenlabs: + with patch( + "tempfile.NamedTemporaryFile", mock_open() + ) as mock_file, patch( + "your_module._import_elevenlabs" + ) as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value - mock_elevenlabs_instance.generate.return_value = (b"fake_audio_data") + mock_elevenlabs_instance.generate.return_value = ( + b"fake_audio_data" + ) eleven_labs_tool.run(SAMPLE_TEXT) assert mock_file.call_args[1]["suffix"] == ".wav" assert mock_file.call_args[1]["delete"] is False @@ -65,11 +80,14 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): with patch("your_module._import_elevenlabs") as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value mock_elevenlabs_instance.generate.side_effect = Exception( - "Test Exception") + "Test Exception" + ) with pytest.raises( - RuntimeError, - match=("Error while running ElevenLabsText2SpeechTool: Test" - " Exception"), + RuntimeError, + match=( + "Error while running ElevenLabsText2SpeechTool: Test" + " Exception" + ), ): eleven_labs_tool.run(SAMPLE_TEXT) @@ -79,7 +97,9 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): "model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL], ) -def test_run_text_to_speech_with_different_models(eleven_labs_tool, model): +def test_run_text_to_speech_with_different_models( + eleven_labs_tool, model +): eleven_labs_tool.model = model speech_file = eleven_labs_tool.run(SAMPLE_TEXT) assert isinstance(speech_file, str) diff --git a/tests/models/test_fire_function_caller.py b/tests/models/test_fire_function_caller.py index 5e859272..082d954d 100644 --- a/tests/models/test_fire_function_caller.py +++ b/tests/models/test_fire_function_caller.py @@ -39,4 +39,6 @@ def test_fire_function_caller_run(mocker): tokenizer.batch_decode.assert_called_once_with(generated_ids) # Assert the decoded output is printed - assert decoded_output in mocker.patch.object(print, "call_args_list") + assert decoded_output in mocker.patch.object( + print, "call_args_list" + ) diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py index 79ce79e5..e76e11bb 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/test_fuyu.py @@ -38,7 +38,9 @@ def fuyu_instance(): # Test using the fixture. def test_fuyu_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) - assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) + assert isinstance( + fuyu_instance.image_processor, FuyuImageProcessor + ) # Test exception when providing an invalid image path. @@ -49,7 +51,6 @@ def test_invalid_image_path(fuyu_instance): # Using monkeypatch to replace the Image.open method to simulate a failure. def test_image_open_failure(fuyu_instance, monkeypatch): - def mock_open(*args, **kwargs): raise Exception("Mocked failure") @@ -78,9 +79,13 @@ def test_tokenizer_type(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance): - assert (fuyu_instance.processor.image_processor == - fuyu_instance.image_processor) - assert (fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer) + assert ( + fuyu_instance.processor.image_processor + == fuyu_instance.image_processor + ) + assert ( + fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer + ) def test_model_device_map(fuyu_instance): @@ -139,14 +144,22 @@ def test_get_img_invalid_path(fuyu_instance): # Test `run` method with valid inputs def test_run_valid_inputs(fuyu_instance): - with patch.object(fuyu_instance, "get_img") as mock_get_img, patch.object( - fuyu_instance, "processor") as mock_processor, patch.object( - fuyu_instance, "model") as mock_model: + with patch.object( + fuyu_instance, "get_img" + ) as mock_get_img, patch.object( + fuyu_instance, "processor" + ) as mock_processor, patch.object( + fuyu_instance, "model" + ) as mock_model: mock_get_img.return_value = "Test image" - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test text"] - result = fuyu_instance.run("Hello, world!", "valid/path/to/image.png") + result = fuyu_instance.run( + "Hello, world!", "valid/path/to/image.png" + ) assert result == ["Test text"] @@ -173,7 +186,9 @@ def test_run_invalid_image_path(fuyu_instance): with patch.object(fuyu_instance, "get_img") as mock_get_img: mock_get_img.side_effect = FileNotFoundError with pytest.raises(FileNotFoundError): - fuyu_instance.run("Hello, world!", "invalid/path/to/image.png") + fuyu_instance.run( + "Hello, world!", "invalid/path/to/image.png" + ) # Test `__init__` method with default parameters diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index efe716ff..a61d1676 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -24,8 +24,12 @@ def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model): assert model.model is mock_genai_model -def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model): - model = Gemini(model_name="custom-model", gemini_api_key="custom-api-key") +def test_gemini_init_custom_params( + mock_gemini_api_key, mock_genai_model +): + model = Gemini( + model_name="custom-model", gemini_api_key="custom-api-key" + ) assert model.model_name == "custom-model" assert model.gemini_api_key == "custom-api-key" assert model.model is mock_genai_model @@ -50,13 +54,16 @@ def test_gemini_run_with_img( response = model.run(task=task, img=img) assert response == "Generated response" - mock_generate_content.assert_called_with(content=[task, "Processed image"]) + mock_generate_content.assert_called_with( + content=[task, "Processed image"] + ) mock_process_img.assert_called_with(img=img) @patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key, - mock_genai_model): +def test_gemini_run_without_img( + mock_generate_content, mock_gemini_api_key, mock_genai_model +): model = Gemini() task = "A cat" response_mock = Mock(text="Generated response") @@ -69,8 +76,9 @@ def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key, @patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_exception(mock_generate_content, mock_gemini_api_key, - mock_genai_model): +def test_gemini_run_exception( + mock_generate_content, mock_gemini_api_key, mock_genai_model +): model = Gemini() task = "A cat" mock_generate_content.side_effect = Exception("Test exception") @@ -88,23 +96,30 @@ def test_gemini_process_img(mock_gemini_api_key, mock_genai_model): with patch("builtins.open", create=True) as open_mock: open_mock.return_value.__enter__.return_value.read.return_value = ( - img_data) + img_data + ) processed_img = model.process_img(img) - assert processed_img == [{"mime_type": "image/png", "data": img_data}] + assert processed_img == [ + {"mime_type": "image/png", "data": img_data} + ] open_mock.assert_called_with(img, "rb") # Test Gemini initialization with missing API key def test_gemini_init_missing_api_key(): - with pytest.raises(ValueError, match="Please provide a Gemini API key"): + with pytest.raises( + ValueError, match="Please provide a Gemini API key" + ): Gemini(gemini_api_key=None) # Test Gemini initialization with missing model name def test_gemini_init_missing_model_name(): - with pytest.raises(ValueError, match="Please provide a model name"): + with pytest.raises( + ValueError, match="Please provide a model name" + ): Gemini(model_name=None) @@ -126,20 +141,26 @@ def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model): # Test Gemini process_img method with missing image -def test_gemini_process_img_missing_image(mock_gemini_api_key, - mock_genai_model): +def test_gemini_process_img_missing_image( + mock_gemini_api_key, mock_genai_model +): model = Gemini() img = None - with pytest.raises(ValueError, match="Please provide an image to process"): + with pytest.raises( + ValueError, match="Please provide an image to process" + ): model.process_img(img=img) # Test Gemini process_img method with missing image type -def test_gemini_process_img_missing_image_type(mock_gemini_api_key, - mock_genai_model): +def test_gemini_process_img_missing_image_type( + mock_gemini_api_key, mock_genai_model +): model = Gemini() img = "cat.png" - with pytest.raises(ValueError, match="Please provide the image type"): + with pytest.raises( + ValueError, match="Please provide the image type" + ): model.process_img(img=img, type=None) @@ -147,7 +168,9 @@ def test_gemini_process_img_missing_image_type(mock_gemini_api_key, def test_gemini_process_img_missing_api_key(mock_genai_model): model = Gemini(gemini_api_key=None) img = "cat.png" - with pytest.raises(ValueError, match="Please provide a Gemini API key"): + with pytest.raises( + ValueError, match="Please provide a Gemini API key" + ): model.process_img(img=img, type="image/png") @@ -170,7 +193,9 @@ def test_gemini_run_mock_img_processing( response = model.run(task=task, img=img) assert response == "Generated response" - mock_generate_content.assert_called_with(content=[task, "Processed image"]) + mock_generate_content.assert_called_with( + content=[task, "Processed image"] + ) mock_process_img.assert_called_with(img=img) diff --git a/tests/models/test_gigabind.py b/tests/models/test_gigabind.py index 493dbca3..3aae0739 100644 --- a/tests/models/test_gigabind.py +++ b/tests/models/test_gigabind.py @@ -11,13 +11,16 @@ except ImportError: @pytest.fixture def api(): - return Gigabind(host="localhost", port=8000, endpoint="embeddings") + return Gigabind( + host="localhost", port=8000, endpoint="embeddings" + ) @pytest.fixture def mock(requests_mock): - requests_mock.post("http://localhost:8000/embeddings", - json={"result": "success"}) + requests_mock.post( + "http://localhost:8000/embeddings", json={"result": "success"} + ) return requests_mock @@ -37,9 +40,9 @@ def test_run_with_audio(api, mock): def test_run_with_all(api, mock): - response = api.run(text="Hello, world!", - vision="image.jpg", - audio="audio.mp3") + response = api.run( + text="Hello, world!", vision="image.jpg", audio="audio.mp3" + ) assert response == {"result": "success"} @@ -62,20 +65,9 @@ def test_retry_on_failure(api, requests_mock): requests_mock.post( "http://localhost:8000/embeddings", [ - { - "status_code": 500, - "json": {} - }, - { - "status_code": 500, - "json": {} - }, - { - "status_code": 200, - "json": { - "result": "success" - } - }, + {"status_code": 500, "json": {}}, + {"status_code": 500, "json": {}}, + {"status_code": 200, "json": {"result": "success"}}, ], ) response = api.run(text="Hello, world!") @@ -86,18 +78,9 @@ def test_retry_exhausted(api, requests_mock): requests_mock.post( "http://localhost:8000/embeddings", [ - { - "status_code": 500, - "json": {} - }, - { - "status_code": 500, - "json": {} - }, - { - "status_code": 500, - "json": {} - }, + {"status_code": 500, "json": {}}, + {"status_code": 500, "json": {}}, + {"status_code": 500, "json": {}}, ], ) response = api.run(text="Hello, world!") @@ -110,7 +93,9 @@ def test_proxy_url(api): def test_invalid_response(api, requests_mock): - requests_mock.post("http://localhost:8000/embeddings", text="not json") + requests_mock.post( + "http://localhost:8000/embeddings", text="not json" + ) response = api.run(text="Hello, world!") assert response is None @@ -125,7 +110,9 @@ def test_connection_error(api, requests_mock): def test_http_error(api, requests_mock): - requests_mock.post("http://localhost:8000/embeddings", status_code=500) + requests_mock.post( + "http://localhost:8000/embeddings", status_code=500 + ) response = api.run(text="Hello, world!") assert response is None @@ -161,7 +148,9 @@ def test_run_with_large_all(api, mock): large_text = "Hello, world! " * 10000 # 10,000 repetitions large_vision = "image.jpg" * 10000 # 10,000 repetitions large_audio = "audio.mp3" * 10000 # 10,000 repetitions - response = api.run(text=large_text, vision=large_vision, audio=large_audio) + response = api.run( + text=large_text, vision=large_vision, audio=large_audio + ) assert response == {"result": "success"} diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index 434ee502..ac797280 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -26,9 +26,9 @@ def test_init(vision_api): def test_encode_image(vision_api): with patch( - "builtins.open", - mock_open(read_data=b"test_image_data"), - create=True, + "builtins.open", + mock_open(read_data=b"test_image_data"), + create=True, ): encoded_image = vision_api.encode_image(img) assert encoded_image == "dGVzdF9pbWFnZV9kYXRh" @@ -37,8 +37,8 @@ def test_encode_image(vision_api): def test_run_success(vision_api): expected_response = {"This is the model's response."} with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api.run("What is this?", img) mock_post.assert_called_once() @@ -46,7 +46,9 @@ def test_run_success(vision_api): def test_run_request_error(vision_api): - with patch("requests.post", side_effect=RequestException("Request Error")): + with patch( + "requests.post", side_effect=RequestException("Request Error") + ): with pytest.raises(RequestException): vision_api.run("What is this?", img) @@ -54,18 +56,20 @@ def test_run_request_error(vision_api): def test_run_response_error(vision_api): expected_response = {"error": "Model Error"} with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ): with pytest.raises(RuntimeError): vision_api.run("What is this?", img) def test_call(vision_api): - expected_response = {"choices": [{"text": "This is the model's response."}]} + expected_response = { + "choices": [{"text": "This is the model's response."}] + } with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api("What is this?", img) mock_post.assert_called_once() @@ -91,7 +95,9 @@ def test_initialization_with_custom_key(): def test_run_with_exception(gpt_api): task = "What is in the image?" img_url = img - with patch("requests.post", side_effect=Exception("Test Exception")): + with patch( + "requests.post", side_effect=Exception("Test Exception") + ): with pytest.raises(Exception): gpt_api.run(task, img_url) @@ -99,10 +105,14 @@ def test_run_with_exception(gpt_api): def test_call_method_successful_response(gpt_api): task = "What is in the image?" img_url = img - response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]} + response_json = { + "choices": [{"text": "Answer from GPT-4 Vision"}] + } mock_response = Mock() mock_response.json.return_value = response_json - with patch("requests.post", return_value=mock_response) as mock_post: + with patch( + "requests.post", return_value=mock_response + ) as mock_post: result = gpt_api(task, img_url) mock_post.assert_called_once() assert result == response_json @@ -111,7 +121,9 @@ def test_call_method_successful_response(gpt_api): def test_call_method_with_exception(gpt_api): task = "What is in the image?" img_url = img - with patch("requests.post", side_effect=Exception("Test Exception")): + with patch( + "requests.post", side_effect=Exception("Test Exception") + ): with pytest.raises(Exception): gpt_api(task, img_url) @@ -119,17 +131,16 @@ def test_call_method_with_exception(gpt_api): @pytest.mark.asyncio async def test_arun_success(vision_api): expected_response = { - "choices": [{ - "message": { - "content": "This is the model's response." - } - }] + "choices": [ + {"message": {"content": "This is the model's response."}} + ] } with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock( - return_value=expected_response)), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock( + json=AsyncMock(return_value=expected_response) + ), ) as mock_post: result = await vision_api.arun("What is this?", img) mock_post.assert_called_once() @@ -139,9 +150,9 @@ async def test_arun_success(vision_api): @pytest.mark.asyncio async def test_arun_request_error(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=Exception("Request Error"), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=Exception("Request Error"), ): with pytest.raises(Exception): await vision_api.arun("What is this?", img) @@ -149,15 +160,13 @@ async def test_arun_request_error(vision_api): def test_run_many_success(vision_api): expected_response = { - "choices": [{ - "message": { - "content": "This is the model's response." - } - }] + "choices": [ + {"message": {"content": "This is the model's response."}} + ] } with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: tasks = ["What is this?", "What is that?"] imgs = [img, img] @@ -170,7 +179,9 @@ def test_run_many_success(vision_api): def test_run_many_request_error(vision_api): - with patch("requests.post", side_effect=RequestException("Request Error")): + with patch( + "requests.post", side_effect=RequestException("Request Error") + ): tasks = ["What is this?", "What is that?"] imgs = [img, img] with pytest.raises(RequestException): @@ -180,9 +191,11 @@ def test_run_many_request_error(vision_api): @pytest.mark.asyncio async def test_arun_json_decode_error(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock( + json=AsyncMock(side_effect=ValueError) + ), ): with pytest.raises(ValueError): await vision_api.arun("What is this?", img) @@ -192,9 +205,11 @@ async def test_arun_json_decode_error(vision_api): async def test_arun_api_error(vision_api): error_response = {"error": {"message": "API Error"}} with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock(return_value=error_response)), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock( + json=AsyncMock(return_value=error_response) + ), ): with pytest.raises(Exception, match="API Error"): await vision_api.arun("What is this?", img) @@ -204,10 +219,11 @@ async def test_arun_api_error(vision_api): async def test_arun_unexpected_response(vision_api): unexpected_response = {"unexpected": "response"} with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock( - return_value=unexpected_response)), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + return_value=AsyncMock( + json=AsyncMock(return_value=unexpected_response) + ), ): with pytest.raises(Exception, match="Unexpected response"): await vision_api.arun("What is this?", img) @@ -216,9 +232,9 @@ async def test_arun_unexpected_response(vision_api): @pytest.mark.asyncio async def test_arun_retries(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=ClientResponseError(None, None), + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=ClientResponseError(None, None), ) as mock_post: with pytest.raises(ClientResponseError): await vision_api.arun("What is this?", img) @@ -228,9 +244,9 @@ async def test_arun_retries(vision_api): @pytest.mark.asyncio async def test_arun_timeout(vision_api): with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=asyncio.TimeoutError, + "aiohttp.ClientSession.post", + new_callable=AsyncMock, + side_effect=asyncio.TimeoutError, ): with pytest.raises(asyncio.TimeoutError): await vision_api.arun("What is this?", img) diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py index 31fc41d4..cbbba940 100644 --- a/tests/models/test_hf.py +++ b/tests/models/test_hf.py @@ -133,7 +133,9 @@ def test_llm_set_repitition_penalty(llm_instance): def test_llm_set_no_repeat_ngram_size(llm_instance): new_no_repeat_ngram_size = 6 llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size) - assert (llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size) + assert ( + llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size + ) # Test for setting temperature @@ -183,7 +185,9 @@ def test_llm_set_model_id(llm_instance): # Test for setting model -@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained") +@patch( + "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" +) def test_llm_set_model(mock_model, llm_instance): mock_model.return_value = "mocked model" llm_instance.set_model(mock_model) diff --git a/tests/models/test_hf_pipeline.py b/tests/models/test_hf_pipeline.py index ec306797..8580dd56 100644 --- a/tests/models/test_hf_pipeline.py +++ b/tests/models/test_hf_pipeline.py @@ -14,14 +14,19 @@ def mock_pipeline(): @pytest.fixture def pipeline(mock_pipeline): - return HuggingfacePipeline("text-generation", - "meta-llama/Llama-2-13b-chat-hf") + return HuggingfacePipeline( + "text-generation", "meta-llama/Llama-2-13b-chat-hf" + ) def test_init(pipeline, mock_pipeline): assert pipeline.task_type == "text-generation" assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf" - assert (pipeline.use_fp8 is True if torch.cuda.is_available() else False) + assert ( + pipeline.use_fp8 is True + if torch.cuda.is_available() + else False + ) mock_pipeline.assert_called_once_with( "text-generation", "meta-llama/Llama-2-13b-chat-hf", @@ -46,5 +51,6 @@ def test_run_with_different_task(pipeline, mock_pipeline): mock_pipeline.return_value = "Generated text" result = pipeline.run("text-classification", "Hello, world!") assert result == "Generated text" - mock_pipeline.assert_called_once_with("text-classification", - "Hello, world!") + mock_pipeline.assert_called_once_with( + "text-classification", "Hello, world!" + ) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 5848e42f..7e19a056 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -18,7 +18,10 @@ def llm_instance(): # Test for instantiation and attributes def test_llm_initialization(llm_instance): - assert (llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha") + assert ( + llm_instance.model_id + == "NousResearch/Nous-Hermes-2-Vision-Alpha" + ) assert llm_instance.max_length == 500 # ... add more assertions for all default attributes @@ -85,12 +88,15 @@ def test_llm_memory_consumption(llm_instance): ) def test_llm_initialization_params(model_id, max_length): if max_length: - instance = HuggingfaceLLM(model_id=model_id, max_length=max_length) + instance = HuggingfaceLLM( + model_id=model_id, max_length=max_length + ) assert instance.max_length == max_length else: instance = HuggingfaceLLM(model_id=model_id) - assert (instance.max_length == 500 - ) # Assuming 500 is the default max_length + assert ( + instance.max_length == 500 + ) # Assuming 500 is the default max_length # Test for setting an invalid device @@ -138,7 +144,9 @@ def test_llm_run_output_length(mock_run, llm_instance): # Test the tokenizer handling special tokens correctly @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") -def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance): +def test_llm_tokenizer_special_tokens( + mock_decode, mock_encode, llm_instance +): mock_encode.return_value = "encoded input with special tokens" mock_decode.return_value = "decoded output with special tokens" result = llm_instance.run("test task with special tokens") @@ -164,8 +172,9 @@ def test_llm_response_time(mock_run, llm_instance): start_time = time.time() llm_instance.run("test task for response time") end_time = time.time() - assert (end_time - start_time - < 1) # Assuming the response should be faster than 1 second + assert ( + end_time - start_time < 1 + ) # Assuming the response should be faster than 1 second # Test the logging of a warning for long inputs @@ -188,10 +197,13 @@ def test_llm_run_model_exception(mock_generate, llm_instance): # Test the behavior when GPU is forced but not available @patch("torch.cuda.is_available", return_value=False) -def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): +def test_llm_force_gpu_when_unavailable( + mock_is_available, llm_instance +): with pytest.raises(EnvironmentError): llm_instance.set_device( - "cuda") # Attempt to set CUDA when it's not available + "cuda" + ) # Attempt to set CUDA when it's not available # Test for proper cleanup after model use (releasing resources) @@ -209,8 +221,9 @@ def test_llm_multilingual_input(mock_run, llm_instance): mock_run.return_value = "mocked multilingual output" multilingual_input = "Bonjour, ceci est un test multilingue." result = llm_instance.run(multilingual_input) - assert isinstance(result, - str) # Simple check to ensure output is string type + assert isinstance( + result, str + ) # Simple check to ensure output is string type # Test caching mechanism to prevent re-running the same inputs diff --git a/tests/models/test_idefics.py b/tests/models/test_idefics.py index cde670f5..3bfee679 100644 --- a/tests/models/test_idefics.py +++ b/tests/models/test_idefics.py @@ -13,8 +13,8 @@ from swarms.models.idefics import ( @pytest.fixture def idefics_instance(): with patch( - "torch.cuda.is_available", - return_value=False): # Assuming tests are run on CPU for simplicity + "torch.cuda.is_available", return_value=False + ): # Assuming tests are run on CPU for simplicity instance = Idefics() return instance @@ -36,8 +36,8 @@ def test_init_default(idefics_instance): ) def test_init_device(device, expected): with patch( - "torch.cuda.is_available", - return_value=True if expected == "cuda" else False, + "torch.cuda.is_available", + return_value=True if expected == "cuda" else False, ): instance = Idefics(device=device) assert instance.device == expected @@ -46,10 +46,14 @@ def test_init_device(device, expected): # Test `run` method def test_run(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, - "processor") as mock_processor, patch.object( - idefics_instance, "model") as mock_model: - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -61,10 +65,14 @@ def test_run(idefics_instance): # Test `__call__` method (using the same logic as run for simplicity) def test_call(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, - "processor") as mock_processor, patch.object( - idefics_instance, "model") as mock_model: - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -77,7 +85,9 @@ def test_call(idefics_instance): def test_chat(idefics_instance): user_input = "User: Hello" response = "Model: Hi there!" - with patch.object(idefics_instance, "run", return_value=[response]): + with patch.object( + idefics_instance, "run", return_value=[response] + ): result = idefics_instance.chat(user_input) assert result == response @@ -87,13 +97,16 @@ def test_chat(idefics_instance): # Test `set_checkpoint` method def test_set_checkpoint(idefics_instance): new_checkpoint = "new_checkpoint" - with patch.object(IdeficsForVisionText2Text, - "from_pretrained") as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained"): + with patch.object( + IdeficsForVisionText2Text, "from_pretrained" + ) as mock_from_pretrained, patch.object( + AutoProcessor, "from_pretrained" + ): idefics_instance.set_checkpoint(new_checkpoint) - mock_from_pretrained.assert_called_with(new_checkpoint, - torch_dtype=torch.bfloat16) + mock_from_pretrained.assert_called_with( + new_checkpoint, torch_dtype=torch.bfloat16 + ) # Test `set_device` method @@ -122,7 +135,7 @@ def test_clear_chat_history(idefics_instance): # Exception Tests def test_run_with_empty_prompts(idefics_instance): with pytest.raises( - Exception + Exception ): # Replace Exception with the actual exception that may arise for an empty prompt. idefics_instance.run([]) @@ -130,10 +143,14 @@ def test_run_with_empty_prompts(idefics_instance): # Test `run` method with batched_mode set to False def test_run_batched_mode_false(idefics_instance): task = "User: Test" - with patch.object(idefics_instance, - "processor") as mock_processor, patch.object( - idefics_instance, "model") as mock_model: - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -146,7 +163,9 @@ def test_run_batched_mode_false(idefics_instance): # Test `run` method with an exception def test_run_with_exception(idefics_instance): task = "User: Test" - with patch.object(idefics_instance, "processor") as mock_processor: + with patch.object( + idefics_instance, "processor" + ) as mock_processor: mock_processor.side_effect = Exception("Test exception") with pytest.raises(Exception): idefics_instance.run(task) @@ -155,21 +174,24 @@ def test_run_with_exception(idefics_instance): # Test `set_model_name` method def test_set_model_name(idefics_instance): new_model_name = "new_model_name" - with patch.object(IdeficsForVisionText2Text, - "from_pretrained") as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained"): + with patch.object( + IdeficsForVisionText2Text, "from_pretrained" + ) as mock_from_pretrained, patch.object( + AutoProcessor, "from_pretrained" + ): idefics_instance.set_model_name(new_model_name) assert idefics_instance.model_name == new_model_name - mock_from_pretrained.assert_called_with(new_model_name, - torch_dtype=torch.bfloat16) + mock_from_pretrained.assert_called_with( + new_model_name, torch_dtype=torch.bfloat16 + ) # Test `__init__` method with device set to None def test_init_device_none(): with patch( - "torch.cuda.is_available", - return_value=False, + "torch.cuda.is_available", + return_value=False, ): instance = Idefics(device=None) assert instance.device == "cpu" @@ -178,8 +200,8 @@ def test_init_device_none(): # Test `__init__` method with device set to "cuda" def test_init_device_cuda(): with patch( - "torch.cuda.is_available", - return_value=True, + "torch.cuda.is_available", + return_value=True, ): instance = Idefics(device="cuda") assert instance.device == "cuda" diff --git a/tests/models/test_kosmos.py b/tests/models/test_kosmos.py index ad3718ca..1219f895 100644 --- a/tests/models/test_kosmos.py +++ b/tests/models/test_kosmos.py @@ -16,7 +16,9 @@ def mock_image_request(): img_data = open(TEST_IMAGE_URL, "rb").read() mock_resp = Mock() mock_resp.raw = img_data - with patch.object(requests, "get", return_value=mock_resp) as _fixture: + with patch.object( + requests, "get", return_value=mock_resp + ) as _fixture: yield _fixture @@ -45,16 +47,18 @@ def test_get_image(mock_image_request): # Test multimodal grounding def test_multimodal_grounding(mock_image_request): kosmos = Kosmos() - kosmos.multimodal_grounding("Find the red apple in the image.", - TEST_IMAGE_URL) + kosmos.multimodal_grounding( + "Find the red apple in the image.", TEST_IMAGE_URL + ) # TODO: Validate the result if possible # Test referring expression comprehension def test_referring_expression_comprehension(mock_image_request): kosmos = Kosmos() - kosmos.referring_expression_comprehension("Show me the green bottle.", - TEST_IMAGE_URL) + kosmos.referring_expression_comprehension( + "Show me the green bottle.", TEST_IMAGE_URL + ) # TODO: Validate the result if possible @@ -89,7 +93,6 @@ IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=fo # Mock response for requests.get() class MockResponse: - @staticmethod def json(): return {} @@ -108,23 +111,30 @@ def kosmos(): # Mocking the requests.get() method @pytest.fixture def mock_request_get(monkeypatch): - monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse()) + monkeypatch.setattr( + requests, "get", lambda url, **kwargs: MockResponse() + ) @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding(kosmos): - kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1) + kosmos.multimodal_grounding( + "Find the red apple in the image.", IMG_URL1 + ) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension(kosmos): - kosmos.referring_expression_comprehension("Show me the green bottle.", - IMG_URL2) + kosmos.referring_expression_comprehension( + "Show me the green bottle.", IMG_URL2 + ) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_generation(kosmos): - kosmos.referring_expression_generation("It is on the table.", IMG_URL3) + kosmos.referring_expression_generation( + "It is on the table.", IMG_URL3 + ) @pytest.mark.usefixtures("mock_request_get") @@ -144,13 +154,16 @@ def test_grounded_image_captioning_detailed(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding_2(kosmos): - kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2) + kosmos.multimodal_grounding( + "Find the yellow fruit in the image.", IMG_URL2 + ) @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension_2(kosmos): - kosmos.referring_expression_comprehension("Where is the water bottle?", - IMG_URL3) + kosmos.referring_expression_comprehension( + "Where is the water bottle?", IMG_URL3 + ) @pytest.mark.usefixtures("mock_request_get") diff --git a/tests/models/test_llama_function_caller.py b/tests/models/test_llama_function_caller.py index f7afb90c..1e9df654 100644 --- a/tests/models/test_llama_function_caller.py +++ b/tests/models/test_llama_function_caller.py @@ -18,7 +18,6 @@ def test_llama_model_loading(llama_caller): # Test adding and calling custom functions def test_llama_custom_function(llama_caller): - def sample_function(arg1, arg2): return f"Sample function called with args: {arg1}, {arg2}" @@ -40,11 +39,13 @@ def test_llama_custom_function(llama_caller): ], ) - result = llama_caller.call_function("sample_function", - arg1="arg1_value", - arg2="arg2_value") + result = llama_caller.call_function( + "sample_function", arg1="arg1_value", arg2="arg2_value" + ) assert ( - result == "Sample function called with args: arg1_value, arg2_value") + result + == "Sample function called with args: arg1_value, arg2_value" + ) # Test streaming user prompts @@ -63,7 +64,6 @@ def test_llama_custom_function_not_found(llama_caller): # Test invalid arguments for custom function def test_llama_custom_function_invalid_arguments(llama_caller): - def sample_function(arg1, arg2): return f"Sample function called with args: {arg1}, {arg2}" @@ -86,7 +86,9 @@ def test_llama_custom_function_invalid_arguments(llama_caller): ) with pytest.raises(TypeError): - llama_caller.call_function("sample_function", arg1="arg1_value") + llama_caller.call_function( + "sample_function", arg1="arg1_value" + ) # Test streaming with custom runtime diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py index ce26a777..a68a9026 100644 --- a/tests/models/test_mixtral.py +++ b/tests/models/test_mixtral.py @@ -21,18 +21,22 @@ def test_mixtral_run(mock_model, mock_tokenizer): mixtral = Mixtral() mock_tokenizer_instance = MagicMock() mock_model_instance = MagicMock() - mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) + mock_tokenizer.from_pretrained.return_value = ( + mock_tokenizer_instance + ) mock_model.from_pretrained.return_value = mock_model_instance mock_tokenizer_instance.return_tensors = "pt" mock_model_instance.generate.return_value = [101, 102, 103] mock_tokenizer_instance.decode.return_value = "Generated text" result = mixtral.run("Test task") assert result == "Generated text" - mock_tokenizer_instance.assert_called_once_with("Test task", - return_tensors="pt") + mock_tokenizer_instance.assert_called_once_with( + "Test task", return_tensors="pt" + ) mock_model_instance.generate.assert_called_once() mock_tokenizer_instance.decode.assert_called_once_with( - [101, 102, 103], skip_special_tokens=True) + [101, 102, 103], skip_special_tokens=True + ) @patch("swarms.models.mixtral.AutoTokenizer") @@ -41,7 +45,9 @@ def test_mixtral_run_error(mock_model, mock_tokenizer): mixtral = Mixtral() mock_tokenizer_instance = MagicMock() mock_model_instance = MagicMock() - mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) + mock_tokenizer.from_pretrained.return_value = ( + mock_tokenizer_instance + ) mock_model.from_pretrained.return_value = mock_model_instance mock_tokenizer_instance.return_tensors = "pt" mock_model_instance.generate.side_effect = Exception("Test error") diff --git a/tests/models/test_mpt7b.py b/tests/models/test_mpt7b.py index e0b49624..92b6c254 100644 --- a/tests/models/test_mpt7b.py +++ b/tests/models/test_mpt7b.py @@ -25,10 +25,14 @@ def test_mpt7b_run(): "EleutherAI/gpt-neox-20b", max_tokens=150, ) - output = mpt.run("generate", "Once upon a time in a land far, far away...") + output = mpt.run( + "generate", "Once upon a time in a land far, far away..." + ) assert isinstance(output, str) - assert output.startswith("Once upon a time in a land far, far away...") + assert output.startswith( + "Once upon a time in a land far, far away..." + ) def test_mpt7b_run_invalid_task(): @@ -51,10 +55,14 @@ def test_mpt7b_generate(): "EleutherAI/gpt-neox-20b", max_tokens=150, ) - output = mpt.generate("Once upon a time in a land far, far away...") + output = mpt.generate( + "Once upon a time in a land far, far away..." + ) assert isinstance(output, str) - assert output.startswith("Once upon a time in a land far, far away...") + assert output.startswith( + "Once upon a time in a land far, far away..." + ) def test_mpt7b_batch_generate(): diff --git a/tests/models/test_nougat.py b/tests/models/test_nougat.py index 2790fd77..858845a6 100644 --- a/tests/models/test_nougat.py +++ b/tests/models/test_nougat.py @@ -43,7 +43,9 @@ def test_model_initialization(setup_nougat): "cuda_available, expected_device", [(True, "cuda"), (False, "cpu")], ) -def test_device_initialization(cuda_available, expected_device, monkeypatch): +def test_device_initialization( + cuda_available, expected_device, monkeypatch +): monkeypatch.setattr( torch, "cuda", @@ -72,7 +74,9 @@ def test_get_image_invalid_path(setup_nougat): (10, 50), ], ) -def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens): +def test_model_call_with_diff_params( + setup_nougat, min_len, max_tokens +): setup_nougat.min_length = min_len setup_nougat.max_new_tokens = max_tokens @@ -103,11 +107,11 @@ def test_model_call_mocked_output(setup_nougat): def mock_processor_and_model(): """Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior.""" with patch( - "transformers.NougatProcessor.from_pretrained", - return_value=Mock(), + "transformers.NougatProcessor.from_pretrained", + return_value=Mock(), ), patch( - "transformers.VisionEncoderDecoderModel.from_pretrained", - return_value=Mock(), + "transformers.VisionEncoderDecoderModel.from_pretrained", + return_value=Mock(), ): yield @@ -118,7 +122,8 @@ def test_nougat_with_sample_image_1(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @@ -135,7 +140,8 @@ def test_nougat_min_length_param(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @@ -146,7 +152,8 @@ def test_nougat_max_new_tokens_param(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @@ -157,13 +164,16 @@ def test_nougat_different_model_path(setup_nougat): os.path.join( "sample_images", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @pytest.mark.usefixtures("mock_processor_and_model") def test_nougat_bad_image_path(setup_nougat): - with pytest.raises(Exception): # Adjust the exception type accordingly. + with pytest.raises( + Exception + ): # Adjust the exception type accordingly. setup_nougat("bad_image_path.png") @@ -173,7 +183,8 @@ def test_nougat_image_large_size(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @@ -183,7 +194,8 @@ def test_nougat_image_small_size(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - )) + ) + ) assert isinstance(result, str) @@ -193,7 +205,8 @@ def test_nougat_image_varied_content(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - )) + ) + ) assert isinstance(result, str) @@ -203,5 +216,6 @@ def test_nougat_image_with_metadata(setup_nougat): os.path.join( "sample_images", "https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - )) + ) + ) assert isinstance(result, str) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 1c6442b5..a920256c 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -4,16 +4,18 @@ from swarms.models.qwen import QwenVLMultiModal def test_post_init(): - with patch("swarms.models.qwen.AutoTokenizer.from_pretrained" - ) as mock_tokenizer, patch( - "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" - ) as mock_model: + with patch( + "swarms.models.qwen.AutoTokenizer.from_pretrained" + ) as mock_tokenizer, patch( + "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" + ) as mock_model: mock_tokenizer.return_value = Mock() mock_model.return_value = Mock() model = QwenVLMultiModal() - mock_tokenizer.assert_called_once_with(model.model_name, - trust_remote_code=True) + mock_tokenizer.assert_called_once_with( + model.model_name, trust_remote_code=True + ) mock_model.assert_called_once_with( model.model_name, device_map=model.device, @@ -23,31 +25,37 @@ def test_post_init(): def test_run(): with patch( - "swarms.models.qwen.AutoTokenizer.from_list_format" + "swarms.models.qwen.AutoTokenizer.from_list_format" ) as mock_format, patch( - "swarms.models.qwen.AutoTokenizer.__call__") as mock_call, patch( - "swarms.models.qwen.AutoModelForCausalLM.generate" - ) as mock_generate, patch( - "swarms.models.qwen.AutoTokenizer.decode") as mock_decode: + "swarms.models.qwen.AutoTokenizer.__call__" + ) as mock_call, patch( + "swarms.models.qwen.AutoModelForCausalLM.generate" + ) as mock_generate, patch( + "swarms.models.qwen.AutoTokenizer.decode" + ) as mock_decode: mock_format.return_value = Mock() mock_call.return_value = Mock() mock_generate.return_value = Mock() mock_decode.return_value = "response" model = QwenVLMultiModal() - response = model.run("Hello, how are you?", - "https://example.com/image.jpg") + response = model.run( + "Hello, how are you?", "https://example.com/image.jpg" + ) assert response == "response" def test_chat(): - with patch("swarms.models.qwen.AutoModelForCausalLM.chat") as mock_chat: + with patch( + "swarms.models.qwen.AutoModelForCausalLM.chat" + ) as mock_chat: mock_chat.return_value = ("response", ["history"]) model = QwenVLMultiModal() - response, history = model.chat("Hello, how are you?", - "https://example.com/image.jpg") + response, history = model.chat( + "Hello, how are you?", "https://example.com/image.jpg" + ) assert response == "response" assert history == ["history"] diff --git a/tests/models/test_speech_t5.py b/tests/models/test_speech_t5.py index 0fc61e6b..d32c21db 100644 --- a/tests/models/test_speech_t5.py +++ b/tests/models/test_speech_t5.py @@ -16,11 +16,16 @@ def speecht5_model(): def test_speecht5_init(speecht5_model): - assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) + assert isinstance( + speecht5_model.processor, SpeechT5.processor.__class__ + ) assert isinstance(speecht5_model.model, SpeechT5.model.__class__) - assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) - assert isinstance(speecht5_model.embeddings_dataset, - torch.utils.data.Dataset) + assert isinstance( + speecht5_model.vocoder, SpeechT5.vocoder.__class__ + ) + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) def test_speecht5_call(speecht5_model): @@ -44,7 +49,10 @@ def test_speecht5_set_model(speecht5_model): speecht5_model.set_model(new_model_name) assert speecht5_model.model_name == new_model_name assert speecht5_model.processor.model_name == new_model_name - assert (speecht5_model.model.config.model_name_or_path == new_model_name) + assert ( + speecht5_model.model.config.model_name_or_path + == new_model_name + ) speecht5_model.set_model(old_model_name) # Restore original model @@ -54,8 +62,12 @@ def test_speecht5_set_vocoder(speecht5_model): speecht5_model.set_vocoder(new_vocoder_name) assert speecht5_model.vocoder_name == new_vocoder_name assert ( - speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name) - speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder + speecht5_model.vocoder.config.model_name_or_path + == new_vocoder_name + ) + speecht5_model.set_vocoder( + old_vocoder_name + ) # Restore original vocoder def test_speecht5_set_embeddings_dataset(speecht5_model): @@ -63,10 +75,12 @@ def test_speecht5_set_embeddings_dataset(speecht5_model): new_dataset_name = "Matthijs/cmu-arctic-xvectors-test" speecht5_model.set_embeddings_dataset(new_dataset_name) assert speecht5_model.dataset_name == new_dataset_name - assert isinstance(speecht5_model.embeddings_dataset, - torch.utils.data.Dataset) + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) speecht5_model.set_embeddings_dataset( - old_dataset_name) # Restore original dataset + old_dataset_name + ) # Restore original dataset def test_speecht5_get_sampling_rate(speecht5_model): @@ -98,7 +112,9 @@ def test_speecht5_change_dataset_split(speecht5_model): def test_speecht5_load_custom_embedding(speecht5_model): xvector = [0.1, 0.2, 0.3, 0.4, 0.5] embedding = speecht5_model.load_custom_embedding(xvector) - assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))) + assert torch.all( + torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)) + ) def test_speecht5_with_different_speakers(speecht5_model): @@ -109,7 +125,9 @@ def test_speecht5_with_different_speakers(speecht5_model): assert isinstance(speech, torch.Tensor) -def test_speecht5_save_speech_with_different_extensions(speecht5_model,): +def test_speecht5_save_speech_with_different_extensions( + speecht5_model, +): text = "Hello, how are you?" speech = speecht5_model(text) extensions = [".wav", ".flac"] @@ -144,4 +162,6 @@ def test_speecht5_change_vocoder_model(speecht5_model): speecht5_model.set_vocoder(new_vocoder_name) speech = speecht5_model(text) assert isinstance(speech, torch.Tensor) - speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder + speecht5_model.set_vocoder( + old_vocoder_name + ) # Restore original vocoder diff --git a/tests/models/test_ssd_1b.py b/tests/models/test_ssd_1b.py index b9b5bb25..f658f853 100644 --- a/tests/models/test_ssd_1b.py +++ b/tests/models/test_ssd_1b.py @@ -21,30 +21,36 @@ def test_ssd1b_call(ssd1b_model): image_url = ssd1b_model(task, neg_prompt) assert isinstance(image_url, str) assert image_url.startswith( - "https://") # Assuming it starts with "https://" + "https://" + ) # Assuming it starts with "https://" # Add more tests for various aspects of the class and methods # Example of a parameterized test for different tasks -@pytest.mark.parametrize("task", - ["A painting of a cat", "A painting of a tree"]) +@pytest.mark.parametrize( + "task", ["A painting of a cat", "A painting of a tree"] +) def test_ssd1b_parameterized_task(ssd1b_model, task): image_url = ssd1b_model(task) assert isinstance(image_url, str) assert image_url.startswith( - "https://") # Assuming it starts with "https://" + "https://" + ) # Assuming it starts with "https://" # Example of a test using mocks to isolate units of code def test_ssd1b_with_mock(ssd1b_model, mocker): - mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline + mocker.patch( + "your_module.StableDiffusionXLPipeline" + ) # Mock the pipeline task = "A painting of a cat" image_url = ssd1b_model(task) assert isinstance(image_url, str) assert image_url.startswith( - "https://") # Assuming it starts with "https://" + "https://" + ) # Assuming it starts with "https://" def test_ssd1b_call_with_cache(ssd1b_model): @@ -62,8 +68,9 @@ def test_ssd1b_invalid_task(ssd1b_model): def test_ssd1b_failed_api_call(ssd1b_model, mocker): - mocker.patch("your_module.StableDiffusionXLPipeline" - ) # Mock the pipeline to raise an exception + mocker.patch( + "your_module.StableDiffusionXLPipeline" + ) # Mock the pipeline to raise an exception task = "A painting of a cat" with pytest.raises(Exception): ssd1b_model(task) diff --git a/tests/models/test_timm.py b/tests/models/test_timm.py index f4dc8dc2..4af689e5 100644 --- a/tests/models/test_timm.py +++ b/tests/models/test_timm.py @@ -19,16 +19,18 @@ def test_timm_model_init(): def test_timm_model_call(): - with patch("swarms.models.timm.create_model") as mock_create_model: + with patch( + "swarms.models.timm.create_model" + ) as mock_create_model: model_name = "resnet18" pretrained = True in_chans = 3 timm_model = TimmModel(model_name, pretrained, in_chans) task = torch.rand(1, in_chans, 224, 224) result = timm_model(task) - mock_create_model.assert_called_once_with(model_name, - pretrained=pretrained, - in_chans=in_chans) + mock_create_model.assert_called_once_with( + model_name, pretrained=pretrained, in_chans=in_chans + ) assert result == mock_create_model.return_value(task) diff --git a/tests/models/test_timm_model.py b/tests/models/test_timm_model.py index 06ec2a93..b2f8f6c9 100644 --- a/tests/models/test_timm_model.py +++ b/tests/models/test_timm_model.py @@ -22,14 +22,17 @@ def test_create_model(sample_model_info): def test_call(sample_model_info): model_handler = TimmModel() input_tensor = torch.randn(1, 3, 224, 224) - output_shape = model_handler.__call__(sample_model_info, input_tensor) + output_shape = model_handler.__call__( + sample_model_info, input_tensor + ) assert isinstance(output_shape, torch.Size) def test_get_supported_models_mock(): model_handler = TimmModel() model_handler._get_supported_models = Mock( - return_value=["resnet18", "resnet50"]) + return_value=["resnet18", "resnet50"] + ) supported_models = model_handler._get_supported_models() assert supported_models == ["resnet18", "resnet50"] diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py index 4b14acc3..dd2a2f89 100644 --- a/tests/models/test_togther.py +++ b/tests/models/test_togther.py @@ -55,11 +55,7 @@ def test_init_custom_params(mock_api_key): def test_run_success(mock_post, mock_api_key): mock_response = Mock() mock_response.json.return_value = { - "choices": [{ - "message": { - "content": "Generated response" - } - }] + "choices": [{"message": {"content": "Generated response"}}] } mock_post.return_value = mock_response @@ -73,7 +69,8 @@ def test_run_success(mock_post, mock_api_key): @patch("swarms.models.together_model.requests.post") def test_run_failure(mock_post, mock_api_key): mock_post.side_effect = requests.exceptions.RequestException( - "Request failed") + "Request failed" + ) model = TogetherLLM() task = "What is the color of the object?" @@ -92,7 +89,9 @@ def test_run_with_logging_enabled(caplog, mock_api_key): assert "Sending request to" in caplog.text -@pytest.mark.parametrize("invalid_input", [None, 123, ["list", "of", "items"]]) +@pytest.mark.parametrize( + "invalid_input", [None, 123, ["list", "of", "items"]] +) def test_invalid_task_input(invalid_input, mock_api_key): model = TogetherLLM() response = model.run(invalid_input) @@ -104,11 +103,7 @@ def test_invalid_task_input(invalid_input, mock_api_key): def test_run_streaming_enabled(mock_post, mock_api_key): mock_response = Mock() mock_response.json.return_value = { - "choices": [{ - "message": { - "content": "Generated response" - } - }] + "choices": [{"message": {"content": "Generated response"}}] } mock_post.return_value = mock_response diff --git a/tests/models/test_ultralytics.py b/tests/models/test_ultralytics.py index 30aa64fb..cca1d023 100644 --- a/tests/models/test_ultralytics.py +++ b/tests/models/test_ultralytics.py @@ -20,7 +20,9 @@ def test_ultralytics_call(): args = (1, 2, 3) kwargs = {"a": "A", "b": "B"} result = ultralytics(task, *args, **kwargs) - mock_yolo.return_value.assert_called_once_with(task, *args, **kwargs) + mock_yolo.return_value.assert_called_once_with( + task, *args, **kwargs + ) assert result == mock_yolo.return_value.return_value diff --git a/tests/models/test_vilt.py b/tests/models/test_vilt.py index dd1e47e2..d849f98e 100644 --- a/tests/models/test_vilt.py +++ b/tests/models/test_vilt.py @@ -21,13 +21,17 @@ def test_vilt_initialization(vilt_instance): # 2. Test Model Predictions @patch.object(requests, "get") @patch.object(Image, "open") -def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): +def test_vilt_prediction( + mock_image_open, mock_requests_get, vilt_instance +): mock_image = Mock() mock_image_open.return_value = mock_image mock_requests_get.return_value.raw = Mock() # It's a mock response, so no real answer expected - with pytest.raises(Exception): # Ensure exception is more specific + with pytest.raises( + Exception + ): # Ensure exception is more specific vilt_instance( "What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", @@ -62,7 +66,9 @@ def test_vilt_network_exception(vilt_instance): ], ) def test_vilt_various_inputs(text, image_url, vilt_instance): - with pytest.raises(Exception): # Again, ensure exception is more specific + with pytest.raises( + Exception + ): # Again, ensure exception is more specific vilt_instance(text, image_url) diff --git a/tests/models/test_yi_200k.py b/tests/models/test_yi_200k.py index 9b2741a2..b31daa3e 100644 --- a/tests/models/test_yi_200k.py +++ b/tests/models/test_yi_200k.py @@ -32,7 +32,9 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length): @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): +def test_yi34b_generate_text_with_temperature( + yi34b_model, temperature +): prompt = "There's a place where time stands still." generated_text = yi34b_model(prompt, temperature=temperature) assert isinstance(generated_text, str) @@ -40,24 +42,27 @@ def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): def test_yi34b_generate_text_with_invalid_prompt(yi34b_model): prompt = None # Invalid prompt - with pytest.raises(ValueError, - match="Input prompt must be a non-empty string"): + with pytest.raises( + ValueError, match="Input prompt must be a non-empty string" + ): yi34b_model(prompt) def test_yi34b_generate_text_with_invalid_max_length(yi34b_model): prompt = "There's a place where time stands still." max_length = -1 # Invalid max_length - with pytest.raises(ValueError, - match="max_length must be a positive integer"): + with pytest.raises( + ValueError, match="max_length must be a positive integer" + ): yi34b_model(prompt, max_length=max_length) def test_yi34b_generate_text_with_invalid_temperature(yi34b_model): prompt = "There's a place where time stands still." temperature = 2.0 # Invalid temperature - with pytest.raises(ValueError, - match="temperature must be between 0.01 and 1.0"): + with pytest.raises( + ValueError, match="temperature must be between 0.01 and 1.0" + ): yi34b_model(prompt, temperature=temperature) @@ -78,32 +83,40 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p): def test_yi34b_generate_text_with_invalid_top_k(yi34b_model): prompt = "There's a place where time stands still." top_k = -1 # Invalid top_k - with pytest.raises(ValueError, - match="top_k must be a non-negative integer"): + with pytest.raises( + ValueError, match="top_k must be a non-negative integer" + ): yi34b_model(prompt, top_k=top_k) def test_yi34b_generate_text_with_invalid_top_p(yi34b_model): prompt = "There's a place where time stands still." top_p = 1.5 # Invalid top_p - with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"): + with pytest.raises( + ValueError, match="top_p must be between 0.0 and 1.0" + ): yi34b_model(prompt, top_p=top_p) @pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5]) -def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, - repitition_penalty): +def test_yi34b_generate_text_with_repitition_penalty( + yi34b_model, repitition_penalty +): prompt = "There's a place where time stands still." - generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty) + generated_text = yi34b_model( + prompt, repitition_penalty=repitition_penalty + ) assert isinstance(generated_text, str) -def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model,): +def test_yi34b_generate_text_with_invalid_repitition_penalty( + yi34b_model, +): prompt = "There's a place where time stands still." repitition_penalty = 0.0 # Invalid repitition_penalty with pytest.raises( - ValueError, - match="repitition_penalty must be a positive float", + ValueError, + match="repitition_penalty must be a positive float", ): yi34b_model(prompt, repitition_penalty=repitition_penalty) diff --git a/tests/models/test_zeroscope.py b/tests/models/test_zeroscope.py index f0b130d1..25a4c597 100644 --- a/tests/models/test_zeroscope.py +++ b/tests/models/test_zeroscope.py @@ -25,11 +25,16 @@ def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) - mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") + mock_pipeline.from_pretrained.return_value = ( + mock_pipeline_instance + ) + mock_pipeline_instance.return_value = MagicMock( + frames="Generated frames" + ) mock_pipeline_instance.enable_vae_slicing.assert_called_once() mock_pipeline_instance.enable_forward_chunking.assert_called_once_with( - chunk_size=1, dim=1) + chunk_size=1, dim=1 + ) result = zeroscope.forward("Test task") assert result == "Generated frames" mock_pipeline_instance.assert_called_once_with( @@ -46,8 +51,12 @@ def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) - mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") + mock_pipeline.from_pretrained.return_value = ( + mock_pipeline_instance + ) + mock_pipeline_instance.return_value = MagicMock( + frames="Generated frames" + ) mock_pipeline_instance.side_effect = Exception("Test error") with pytest.raises(Exception, match="Test error"): zeroscope.forward("Test task") @@ -58,8 +67,12 @@ def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) - mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") + mock_pipeline.from_pretrained.return_value = ( + mock_pipeline_instance + ) + mock_pipeline_instance.return_value = MagicMock( + frames="Generated frames" + ) result = zeroscope.__call__("Test task") assert result == "Generated frames" mock_pipeline_instance.assert_called_once_with( @@ -76,8 +89,12 @@ def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) - mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") + mock_pipeline.from_pretrained.return_value = ( + mock_pipeline_instance + ) + mock_pipeline_instance.return_value = MagicMock( + frames="Generated frames" + ) mock_pipeline_instance.side_effect = Exception("Test error") with pytest.raises(Exception, match="Test error"): zeroscope.__call__("Test task") @@ -88,8 +105,12 @@ def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline): def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline): zeroscope = ZeroscopeTTV() mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance) - mock_pipeline_instance.return_value = MagicMock(frames="Generated frames") + mock_pipeline.from_pretrained.return_value = ( + mock_pipeline_instance + ) + mock_pipeline_instance.return_value = MagicMock( + frames="Generated frames" + ) result = zeroscope.save_video_path("Test video path") assert result == "Test video path" mock_pipeline_instance.assert_called_once_with( diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py index bfb3b365..5be7f31a 100644 --- a/tests/structs/test_agent.py +++ b/tests/structs/test_agent.py @@ -18,7 +18,9 @@ openai_api_key = os.getenv("OPENAI_API_KEY") # Mocks and Fixtures @pytest.fixture def mocked_llm(): - return OpenAIChat(openai_api_key=openai_api_key,) + return OpenAIChat( + openai_api_key=openai_api_key, + ) @pytest.fixture @@ -63,12 +65,15 @@ def test_provide_feedback(basic_flow): @patch("time.sleep", return_value=None) # to speed up tests def test_run_without_stopping_condition(mocked_sleep, basic_flow): response = basic_flow.run("Test task") - assert (response == "Test task" - ) # since our mocked llm doesn't modify the response + assert ( + response == "Test task" + ) # since our mocked llm doesn't modify the response @patch("time.sleep", return_value=None) # to speed up tests -def test_run_with_stopping_condition(mocked_sleep, flow_with_condition): +def test_run_with_stopping_condition( + mocked_sleep, flow_with_condition +): response = flow_with_condition.run("Stop") assert response == "Stop" @@ -108,7 +113,6 @@ def test_env_variable_handling(monkeypatch): # Test initializing the agent with different stopping conditions def test_flow_with_custom_stopping_condition(mocked_llm): - def stopping_condition(x): return "terminate" in x.lower() @@ -129,7 +133,9 @@ def test_flow_call(basic_flow): # Test formatting the prompt def test_format_prompt(basic_flow): - formatted_prompt = basic_flow.format_prompt("Hello {name}", name="John") + formatted_prompt = basic_flow.format_prompt( + "Hello {name}", name="John" + ) assert formatted_prompt == "Hello John" @@ -158,15 +164,9 @@ def test_interactive_mode(basic_flow): # Test bulk run with varied inputs def test_bulk_run_varied_inputs(basic_flow): inputs = [ - { - "task": "Test1" - }, - { - "task": "Test2" - }, - { - "task": "Stop now" - }, + {"task": "Test1"}, + {"task": "Test2"}, + {"task": "Stop now"}, ] responses = basic_flow.bulk_run(inputs) assert responses == ["Test1", "Test2", "Stop now"] @@ -191,9 +191,12 @@ def test_save_different_memory(basic_flow, tmp_path): # Test the stopping condition check def test_check_stopping_condition(flow_with_condition): - assert flow_with_condition._check_stopping_condition("Stop this process") + assert flow_with_condition._check_stopping_condition( + "Stop this process" + ) assert not flow_with_condition._check_stopping_condition( - "Continue the task") + "Continue the task" + ) # Test without providing max loops (default value should be 5) @@ -249,7 +252,9 @@ def test_different_retry_intervals(mocked_sleep, basic_flow): # Test invoking the agent with additional kwargs @patch("time.sleep", return_value=None) def test_flow_call_with_kwargs(mocked_sleep, basic_flow): - response = basic_flow("Test call", param1="value1", param2="value2") + response = basic_flow( + "Test call", param1="value1", param2="value2" + ) assert response == "Test call" @@ -284,7 +289,9 @@ def test_stopping_token_in_response(mocked_sleep, basic_flow): def flow_instance(): # Create an instance of the Agent class with required parameters for testing # You may need to adjust this based on your actual class initialization - llm = OpenAIChat(openai_api_key=openai_api_key,) + llm = OpenAIChat( + openai_api_key=openai_api_key, + ) agent = Agent( llm=llm, max_loops=5, @@ -331,7 +338,9 @@ def test_flow_autosave(flow_instance): def test_flow_response_filtering(flow_instance): # Test the response filtering functionality flow_instance.add_response_filter("filter_this") - response = flow_instance.filtered_run("This message should filter_this") + response = flow_instance.filtered_run( + "This message should filter_this" + ) assert "filter_this" not in response @@ -391,8 +400,11 @@ def test_flow_response_length(flow_instance): # Test checking the length of the response response = flow_instance.run( "Generate a 10,000 word long blog on mental clarity and the" - " benefits of meditation.") - assert (len(response) > flow_instance.get_response_length_threshold()) + " benefits of meditation." + ) + assert ( + len(response) > flow_instance.get_response_length_threshold() + ) def test_flow_set_response_length_threshold(flow_instance): @@ -481,7 +493,9 @@ def test_flow_get_conversation_log(flow_instance): flow_instance.run("Message 1") flow_instance.run("Message 2") conversation_log = flow_instance.get_conversation_log() - assert (len(conversation_log) == 4) # Including system and user messages + assert ( + len(conversation_log) == 4 + ) # Including system and user messages def test_flow_clear_conversation_log(flow_instance): @@ -565,18 +579,37 @@ def test_flow_rollback(flow_instance): flow_instance.change_prompt("New prompt") flow_instance.get_state() flow_instance.rollback_to_state(state1) - assert (flow_instance.get_current_prompt() == state1["current_prompt"]) + assert ( + flow_instance.get_current_prompt() == state1["current_prompt"] + ) assert flow_instance.get_instructions() == state1["instructions"] - assert (flow_instance.get_user_messages() == state1["user_messages"]) - assert (flow_instance.get_response_history() == state1["response_history"]) - assert (flow_instance.get_conversation_log() == state1["conversation_log"]) - assert (flow_instance.is_dynamic_pacing_enabled() == - state1["dynamic_pacing_enabled"]) - assert (flow_instance.get_response_length_threshold() == - state1["response_length_threshold"]) - assert (flow_instance.get_response_filters() == state1["response_filters"]) + assert ( + flow_instance.get_user_messages() == state1["user_messages"] + ) + assert ( + flow_instance.get_response_history() + == state1["response_history"] + ) + assert ( + flow_instance.get_conversation_log() + == state1["conversation_log"] + ) + assert ( + flow_instance.is_dynamic_pacing_enabled() + == state1["dynamic_pacing_enabled"] + ) + assert ( + flow_instance.get_response_length_threshold() + == state1["response_length_threshold"] + ) + assert ( + flow_instance.get_response_filters() + == state1["response_filters"] + ) assert flow_instance.get_max_loops() == state1["max_loops"] - assert (flow_instance.get_autosave_path() == state1["autosave_path"]) + assert ( + flow_instance.get_autosave_path() == state1["autosave_path"] + ) assert flow_instance.get_state() == state1 @@ -585,7 +618,8 @@ def test_flow_contextual_intent(flow_instance): flow_instance.add_context("location", "New York") flow_instance.add_context("time", "tomorrow") response = flow_instance.run( - "What's the weather like in {location} at {time}?") + "What's the weather like in {location} at {time}?" + ) assert "New York" in response assert "tomorrow" in response @@ -593,9 +627,13 @@ def test_flow_contextual_intent(flow_instance): def test_flow_contextual_intent_override(flow_instance): # Test contextual intent override flow_instance.add_context("location", "New York") - response1 = flow_instance.run("What's the weather like in {location}?") + response1 = flow_instance.run( + "What's the weather like in {location}?" + ) flow_instance.add_context("location", "Los Angeles") - response2 = flow_instance.run("What's the weather like in {location}?") + response2 = flow_instance.run( + "What's the weather like in {location}?" + ) assert "New York" in response1 assert "Los Angeles" in response2 @@ -603,9 +641,13 @@ def test_flow_contextual_intent_override(flow_instance): def test_flow_contextual_intent_reset(flow_instance): # Test resetting contextual intent flow_instance.add_context("location", "New York") - response1 = flow_instance.run("What's the weather like in {location}?") + response1 = flow_instance.run( + "What's the weather like in {location}?" + ) flow_instance.reset_context() - response2 = flow_instance.run("What's the weather like in {location}?") + response2 = flow_instance.run( + "What's the weather like in {location}?" + ) assert "New York" in response1 assert "New York" in response2 @@ -630,7 +672,9 @@ def test_flow_non_interruptible(flow_instance): def test_flow_timeout(flow_instance): # Test conversation timeout flow_instance.timeout = 60 # Set a timeout of 60 seconds - response = flow_instance.run("This should take some time to respond.") + response = flow_instance.run( + "This should take some time to respond." + ) assert "Timed out" in response assert flow_instance.is_timed_out() is True @@ -679,14 +723,20 @@ def test_flow_save_and_load_conversation(flow_instance): def test_flow_inject_custom_system_message(flow_instance): # Test injecting a custom system message into the conversation - flow_instance.inject_custom_system_message("Custom system message") - assert ("Custom system message" in flow_instance.get_message_history()) + flow_instance.inject_custom_system_message( + "Custom system message" + ) + assert ( + "Custom system message" in flow_instance.get_message_history() + ) def test_flow_inject_custom_user_message(flow_instance): # Test injecting a custom user message into the conversation flow_instance.inject_custom_user_message("Custom user message") - assert ("Custom user message" in flow_instance.get_message_history()) + assert ( + "Custom user message" in flow_instance.get_message_history() + ) def test_flow_inject_custom_response(flow_instance): @@ -697,28 +747,45 @@ def test_flow_inject_custom_response(flow_instance): def test_flow_clear_injected_messages(flow_instance): # Test clearing injected messages from the conversation - flow_instance.inject_custom_system_message("Custom system message") + flow_instance.inject_custom_system_message( + "Custom system message" + ) flow_instance.inject_custom_user_message("Custom user message") flow_instance.inject_custom_response("Custom response") flow_instance.clear_injected_messages() - assert ("Custom system message" not in flow_instance.get_message_history()) - assert ("Custom user message" not in flow_instance.get_message_history()) - assert ("Custom response" not in flow_instance.get_message_history()) + assert ( + "Custom system message" + not in flow_instance.get_message_history() + ) + assert ( + "Custom user message" + not in flow_instance.get_message_history() + ) + assert ( + "Custom response" not in flow_instance.get_message_history() + ) def test_flow_disable_message_history(flow_instance): # Test disabling message history recording flow_instance.disable_message_history() response = flow_instance.run( - "This message should not be recorded in history.") - assert ("This message should not be recorded in history." in response) - assert (len(flow_instance.get_message_history()) == 0) # History is empty + "This message should not be recorded in history." + ) + assert ( + "This message should not be recorded in history." in response + ) + assert ( + len(flow_instance.get_message_history()) == 0 + ) # History is empty def test_flow_enable_message_history(flow_instance): # Test enabling message history recording flow_instance.enable_message_history() - response = flow_instance.run("This message should be recorded in history.") + response = flow_instance.run( + "This message should be recorded in history." + ) assert "This message should be recorded in history." in response assert len(flow_instance.get_message_history()) == 1 @@ -728,7 +795,9 @@ def test_flow_custom_logger(flow_instance): custom_logger = logger # Replace with your custom logger class flow_instance.set_logger(custom_logger) response = flow_instance.run("Custom logger test") - assert ("Logged using custom logger" in response) # Verify logging message + assert ( + "Logged using custom logger" in response + ) # Verify logging message def test_flow_batch_processing(flow_instance): @@ -802,35 +871,43 @@ def test_flow_input_validation(flow_instance): with pytest.raises(ValueError): flow_instance.set_message_delimiter( - "") # Empty delimiter, should raise ValueError + "" + ) # Empty delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - None) # None delimiter, should raise ValueError + None + ) # None delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - 123) # Invalid delimiter type, should raise ValueError + 123 + ) # Invalid delimiter type, should raise ValueError with pytest.raises(ValueError): flow_instance.set_logger( - "invalid_logger") # Invalid logger type, should raise ValueError + "invalid_logger" + ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, - "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", - None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - None, "value") # None key, should raise ValueError + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - "key", None) # None value, should raise ValueError + "key", None + ) # None value, should raise ValueError def test_flow_conversation_reset(flow_instance): @@ -857,7 +934,6 @@ def test_flow_conversation_persistence(flow_instance): def test_flow_custom_event_listener(flow_instance): # Test custom event listener class CustomEventListener: - def on_message_received(self, message): pass @@ -869,10 +945,10 @@ def test_flow_custom_event_listener(flow_instance): # Ensure that the custom event listener methods are called during a conversation with mock.patch.object( - custom_event_listener, - "on_message_received") as mock_received, mock.patch.object( - custom_event_listener, - "on_response_generated") as mock_response: + custom_event_listener, "on_message_received" + ) as mock_received, mock.patch.object( + custom_event_listener, "on_response_generated" + ) as mock_response: flow_instance.run("Message 1") mock_received.assert_called_once() mock_response.assert_called_once() @@ -881,7 +957,6 @@ def test_flow_custom_event_listener(flow_instance): def test_flow_multiple_event_listeners(flow_instance): # Test multiple event listeners class FirstEventListener: - def on_message_received(self, message): pass @@ -889,7 +964,6 @@ def test_flow_multiple_event_listeners(flow_instance): pass class SecondEventListener: - def on_message_received(self, message): pass @@ -903,14 +977,14 @@ def test_flow_multiple_event_listeners(flow_instance): # Ensure that both event listeners receive events during a conversation with mock.patch.object( - first_event_listener, - "on_message_received") as mock_first_received, mock.patch.object( - first_event_listener, "on_response_generated" - ) as mock_first_response, mock.patch.object( - second_event_listener, "on_message_received" - ) as mock_second_received, mock.patch.object( - second_event_listener, - "on_response_generated") as mock_second_response: + first_event_listener, "on_message_received" + ) as mock_first_received, mock.patch.object( + first_event_listener, "on_response_generated" + ) as mock_first_response, mock.patch.object( + second_event_listener, "on_message_received" + ) as mock_second_received, mock.patch.object( + second_event_listener, "on_response_generated" + ) as mock_second_response: flow_instance.run("Message 1") mock_first_received.assert_called_once() mock_first_response.assert_called_once() @@ -923,31 +997,38 @@ def test_flow_error_handling(flow_instance): # Test error handling and exceptions with pytest.raises(ValueError): flow_instance.set_message_delimiter( - "") # Empty delimiter, should raise ValueError + "" + ) # Empty delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_message_delimiter( - None) # None delimiter, should raise ValueError + None + ) # None delimiter, should raise ValueError with pytest.raises(ValueError): flow_instance.set_logger( - "invalid_logger") # Invalid logger type, should raise ValueError + "invalid_logger" + ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, - "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", - None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - None, "value") # None key, should raise ValueError + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): flow_instance.update_context( - "key", None) # None value, should raise ValueError + "key", None + ) # None value, should raise ValueError def test_flow_context_operations(flow_instance): @@ -984,8 +1065,14 @@ def test_flow_custom_response(flow_instance): flow_instance.set_response_generator(custom_response_generator) assert flow_instance.run("Hello") == "Hi there!" - assert (flow_instance.run("How are you?") == "I'm doing well, thank you.") - assert (flow_instance.run("What's your name?") == "I don't understand.") + assert ( + flow_instance.run("How are you?") + == "I'm doing well, thank you." + ) + assert ( + flow_instance.run("What's your name?") + == "I don't understand." + ) def test_flow_message_validation(flow_instance): @@ -996,8 +1083,12 @@ def test_flow_message_validation(flow_instance): flow_instance.set_message_validator(custom_message_validator) assert flow_instance.run("Valid message") is not None - assert (flow_instance.run("") is None) # Empty message should be rejected - assert (flow_instance.run(None) is None) # None message should be rejected + assert ( + flow_instance.run("") is None + ) # Empty message should be rejected + assert ( + flow_instance.run(None) is None + ) # None message should be rejected def test_flow_custom_logging(flow_instance): @@ -1022,10 +1113,15 @@ def test_flow_complex_use_case(flow_instance): flow_instance.add_context("user_id", "12345") flow_instance.run("Hello") flow_instance.run("How can I help you?") - assert (flow_instance.get_response() == "Please provide more details.") + assert ( + flow_instance.get_response() == "Please provide more details." + ) flow_instance.update_context("user_id", "54321") flow_instance.run("I need help with my order") - assert (flow_instance.get_response() == "Sure, I can assist with that.") + assert ( + flow_instance.get_response() + == "Sure, I can assist with that." + ) flow_instance.reset_conversation() assert len(flow_instance.get_message_history()) == 0 assert flow_instance.get_context("user_id") is None @@ -1064,7 +1160,9 @@ def test_flow_concurrent_requests(flow_instance): def test_flow_custom_timeout(flow_instance): # Test custom timeout handling - flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds + flow_instance.set_timeout( + 10 + ) # Set a custom timeout of 10 seconds assert flow_instance.get_timeout() == 10 import time @@ -1115,10 +1213,16 @@ def test_flow_agent_history_prompt(flow_instance): history = ["User: Hi", "AI: Hello"] agent_history_prompt = flow_instance.agent_history_prompt( - system_prompt, history) + system_prompt, history + ) - assert ("SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt) - assert ("History: ['User: Hi', 'AI: Hello']" in agent_history_prompt) + assert ( + "SYSTEM_PROMPT: This is the system prompt." + in agent_history_prompt + ) + assert ( + "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt + ) async def test_flow_run_concurrent(flow_instance): @@ -1133,18 +1237,9 @@ async def test_flow_run_concurrent(flow_instance): def test_flow_bulk_run(flow_instance): # Test bulk running of tasks input_data = [ - { - "task": "Task 1", - "param1": "value1" - }, - { - "task": "Task 2", - "param2": "value2" - }, - { - "task": "Task 3", - "param3": "value3" - }, + {"task": "Task 1", "param1": "value1"}, + {"task": "Task 2", "param2": "value2"}, + {"task": "Task 3", "param3": "value3"}, ] responses = flow_instance.bulk_run(input_data) @@ -1159,7 +1254,9 @@ def test_flow_from_llm_and_template(): llm_instance = mocked_llm # Replace with your LLM class template = "This is a template for testing." - flow_instance = Agent.from_llm_and_template(llm_instance, template) + flow_instance = Agent.from_llm_and_template( + llm_instance, template + ) assert isinstance(flow_instance, Agent) @@ -1168,10 +1265,12 @@ def test_flow_from_llm_and_template_file(): # Test creating Agent instance from an LLM and a template file llm_instance = mocked_llm # Replace with your LLM class template_file = ( # Create a template file for testing - "template.txt") + "template.txt" + ) - flow_instance = Agent.from_llm_and_template_file(llm_instance, - template_file) + flow_instance = Agent.from_llm_and_template_file( + llm_instance, template_file + ) assert isinstance(flow_instance, Agent) diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py index b0d00606..2e5585bf 100644 --- a/tests/structs/test_autoscaler.py +++ b/tests/structs/test_autoscaler.py @@ -44,7 +44,9 @@ def test_autoscaler_run(): agent.id, "Generate a 10,000 word blog on health and wellness.", ) - assert (out == "Generate a 10,000 word blog on health and wellness.") + assert ( + out == "Generate a 10,000 word blog on health and wellness." + ) def test_autoscaler_add_agent(): @@ -237,7 +239,9 @@ def test_autoscaler_add_task(): def test_autoscaler_scale_up(): - autoscaler = AutoScaler(initial_agents=5, scale_up_factor=2, agent=agent) + autoscaler = AutoScaler( + initial_agents=5, scale_up_factor=2, agent=agent + ) autoscaler.scale_up() assert len(autoscaler.agents_pool) == 10 diff --git a/tests/structs/test_base.py b/tests/structs/test_base.py index 3cbe5c3d..971f966b 100644 --- a/tests/structs/test_base.py +++ b/tests/structs/test_base.py @@ -7,7 +7,6 @@ from swarms.structs.base import BaseStructure class TestBaseStructure: - def test_init(self): base_structure = BaseStructure( name="TestStructure", @@ -89,8 +88,11 @@ class TestBaseStructure: with open(log_file) as file: lines = file.readlines() assert len(lines) == 1 - assert (lines[0] == f"[{base_structure._current_timestamp()}]" - f" [{event_type}] {event}\n") + assert ( + lines[0] + == f"[{base_structure._current_timestamp()}]" + f" [{event_type}] {event}\n" + ) @pytest.mark.asyncio async def test_run_async(self): @@ -134,7 +136,9 @@ class TestBaseStructure: artifact = {"key": "value"} artifact_name = "test_artifact" - await base_structure.save_artifact_async(artifact, artifact_name) + await base_structure.save_artifact_async( + artifact, artifact_name + ) loaded_artifact = base_structure.load_artifact(artifact_name) assert loaded_artifact == artifact @@ -147,8 +151,9 @@ class TestBaseStructure: artifact = {"key": "value"} artifact_name = "test_artifact" base_structure.save_artifact(artifact, artifact_name) - loaded_artifact = await base_structure.load_artifact_async(artifact_name - ) + loaded_artifact = await base_structure.load_artifact_async( + artifact_name + ) assert loaded_artifact == artifact @@ -165,8 +170,11 @@ class TestBaseStructure: with open(log_file) as file: lines = file.readlines() assert len(lines) == 1 - assert (lines[0] == f"[{base_structure._current_timestamp()}]" - f" [{event_type}] {event}\n") + assert ( + lines[0] + == f"[{base_structure._current_timestamp()}]" + f" [{event_type}] {event}\n" + ) @pytest.mark.asyncio async def test_asave_to_file(self, tmpdir): @@ -193,14 +201,18 @@ class TestBaseStructure: def test_run_in_thread(self): base_structure = BaseStructure() - result = base_structure.run_in_thread(lambda: "Thread Test Result") + result = base_structure.run_in_thread( + lambda: "Thread Test Result" + ) assert result.result() == "Thread Test Result" def test_save_and_decompress_data(self): base_structure = BaseStructure() data = {"key": "value"} compressed_data = base_structure.compress_data(data) - decompressed_data = base_structure.decompres_data(compressed_data) + decompressed_data = base_structure.decompres_data( + compressed_data + ) assert decompressed_data == data def test_run_batched(self): @@ -210,11 +222,13 @@ class TestBaseStructure: return f"Processed {data}" batched_data = list(range(10)) - result = base_structure.run_batched(batched_data, - batch_size=5, - func=run_function) + result = base_structure.run_batched( + batched_data, batch_size=5, func=run_function + ) - expected_result = [f"Processed {data}" for data in batched_data] + expected_result = [ + f"Processed {data}" for data in batched_data + ] assert result == expected_result def test_load_config(self, tmpdir): @@ -232,12 +246,15 @@ class TestBaseStructure: tmp_dir = tmpdir.mkdir("test_dir") base_structure = BaseStructure() data_to_backup = {"key": "value"} - base_structure.backup_data(data_to_backup, backup_path=tmp_dir) + base_structure.backup_data( + data_to_backup, backup_path=tmp_dir + ) backup_files = os.listdir(tmp_dir) assert len(backup_files) == 1 loaded_data = base_structure.load_from_file( - os.path.join(tmp_dir, backup_files[0])) + os.path.join(tmp_dir, backup_files[0]) + ) assert loaded_data == data_to_backup def test_monitor_resources(self): @@ -262,9 +279,11 @@ class TestBaseStructure: return f"Processed {data}" batched_data = list(range(10)) - result = base_structure.run_with_resources_batched(batched_data, - batch_size=5, - func=run_function) + result = base_structure.run_with_resources_batched( + batched_data, batch_size=5, func=run_function + ) - expected_result = [f"Processed {data}" for data in batched_data] + expected_result = [ + f"Processed {data}" for data in batched_data + ] assert result == expected_result diff --git a/tests/structs/test_base_workflow.py b/tests/structs/test_base_workflow.py index 2d94caf2..ccb7a563 100644 --- a/tests/structs/test_base_workflow.py +++ b/tests/structs/test_base_workflow.py @@ -30,8 +30,13 @@ def test_load_workflow_state(): workflow.load_workflow_state("workflow_state.json") assert workflow.max_loops == 1 assert len(workflow.tasks) == 2 - assert (workflow.tasks[0].description == "What's the weather in miami") - assert (workflow.tasks[1].description == "Create a report on these metrics") + assert ( + workflow.tasks[0].description == "What's the weather in miami" + ) + assert ( + workflow.tasks[1].description + == "Create a report on these metrics" + ) teardown_workflow() diff --git a/tests/structs/test_concurrent_workflow.py b/tests/structs/test_concurrent_workflow.py index 9a3f46da..e3fabdd5 100644 --- a/tests/structs/test_concurrent_workflow.py +++ b/tests/structs/test_concurrent_workflow.py @@ -18,7 +18,9 @@ def test_run(): workflow.add(task1) workflow.add(task2) - with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor: + with patch( + "concurrent.futures.ThreadPoolExecutor" + ) as mock_executor: future1 = Future() future1.set_result(None) future2 = Future() diff --git a/tests/structs/test_conversation.py b/tests/structs/test_conversation.py index 67f083f3..049f3fb3 100644 --- a/tests/structs/test_conversation.py +++ b/tests/structs/test_conversation.py @@ -87,13 +87,16 @@ def test_return_history_as_string_with_different_roles(role, content): @pytest.mark.parametrize("message_count", range(1, 11)) -def test_return_history_as_string_with_multiple_messages(message_count,): +def test_return_history_as_string_with_multiple_messages( + message_count, +): conv = Conversation() for i in range(message_count): conv.add("user", f"Message {i + 1}") result = conv.return_history_as_string() expected = "".join( - [f"user: Message {i + 1}\n\n" for i in range(message_count)]) + [f"user: Message {i + 1}\n\n" for i in range(message_count)] + ) assert result == expected @@ -119,8 +122,10 @@ def test_return_history_as_string_with_large_message(conversation): large_message = "Hello, world! " * 10000 # 10,000 repetitions conversation.add("user", large_message) result = conversation.return_history_as_string() - expected = ("user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" - f" {large_message}\n\n") + expected = ( + "user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" + f" {large_message}\n\n" + ) assert result == expected @@ -136,8 +141,10 @@ def test_export_import_conversation(conversation, tmp_path): conversation.export_conversation(filename) new_conversation = Conversation() new_conversation.import_conversation(filename) - assert (new_conversation.return_history_as_string() == - conversation.return_history_as_string()) + assert ( + new_conversation.return_history_as_string() + == conversation.return_history_as_string() + ) def test_count_messages_by_role(conversation): diff --git a/tests/structs/test_groupchat.py b/tests/structs/test_groupchat.py index d9635e79..e8096d9c 100644 --- a/tests/structs/test_groupchat.py +++ b/tests/structs/test_groupchat.py @@ -11,7 +11,6 @@ llm2 = Anthropic() # Mock the OpenAI class for testing class MockOpenAI: - def __init__(self, *args, **kwargs): pass @@ -140,9 +139,9 @@ def test_groupchat_manager_generate_reply(): selector = agent1 # Initialize GroupChatManager - manager = GroupChatManager(groupchat=groupchat, - selector=selector, - openai=mocked_openai) + manager = GroupChatManager( + groupchat=groupchat, selector=selector, openai=mocked_openai + ) # Generate a reply task = "Write me a riddle" @@ -166,8 +165,9 @@ def test_groupchat_select_speaker(): # Simulate selecting the next speaker last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, - selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is agent2 assert next_speaker == agent2 @@ -185,8 +185,9 @@ def test_groupchat_underpopulated_group(): # Simulate selecting the next speaker in an underpopulated group last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, - selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker in an underpopulated group assert next_speaker == last_speaker @@ -204,13 +205,15 @@ def test_groupchat_max_rounds(): # Simulate the conversation with max rounds last_speaker = agent1 for _ in range(2): - next_speaker = manager.select_speaker(last_speaker=last_speaker, - selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) last_speaker = next_speaker # Try one more round, should stay with the last speaker - next_speaker = manager.select_speaker(last_speaker=last_speaker, - selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker after reaching max rounds assert next_speaker == last_speaker diff --git a/tests/structs/test_json.py b/tests/structs/test_json.py index c0e1e851..9ba11072 100644 --- a/tests/structs/test_json.py +++ b/tests/structs/test_json.py @@ -15,8 +15,10 @@ def valid_schema_path(tmp_path): d = tmp_path / "sub" d.mkdir() p = d / "schema.json" - p.write_text('{"type": "object", "properties": {"name": {"type":' - ' "string"}}}') + p.write_text( + '{"type": "object", "properties": {"name": {"type":' + ' "string"}}}' + ) return str(p) @@ -31,7 +33,6 @@ def invalid_schema_path(tmp_path): # This test class must be subclassed as JSON class is abstract class TestableJSON(JSON): - def validate(self, data): # Here must be a real validation implementation for testing pass diff --git a/tests/structs/test_majority_voting.py b/tests/structs/test_majority_voting.py index b6f09020..dcd25f0b 100644 --- a/tests/structs/test_majority_voting.py +++ b/tests/structs/test_majority_voting.py @@ -35,9 +35,15 @@ def test_majority_voting_run_concurrent(mocker): majority_vote = mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with("What is the capital of France?") - agent2.run.assert_called_once_with("What is the capital of France?") - agent3.run.assert_called_once_with("What is the capital of France?") + agent1.run.assert_called_once_with( + "What is the capital of France?" + ) + agent2.run.assert_called_once_with( + "What is the capital of France?" + ) + agent3.run.assert_called_once_with( + "What is the capital of France?" + ) # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) @@ -77,9 +83,15 @@ def test_majority_voting_run_multithreaded(mocker): majority_vote = mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with("What is the capital of France?") - agent2.run.assert_called_once_with("What is the capital of France?") - agent3.run.assert_called_once_with("What is the capital of France?") + agent1.run.assert_called_once_with( + "What is the capital of France?" + ) + agent2.run.assert_called_once_with( + "What is the capital of France?" + ) + agent3.run.assert_called_once_with( + "What is the capital of France?" + ) # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) @@ -121,9 +133,15 @@ async def test_majority_voting_run_asynchronous(mocker): majority_vote = await mv.run("What is the capital of France?") # Assert agent.run method was called with the correct task - agent1.run.assert_called_once_with("What is the capital of France?") - agent2.run.assert_called_once_with("What is the capital of France?") - agent3.run.assert_called_once_with("What is the capital of France?") + agent1.run.assert_called_once_with( + "What is the capital of France?" + ) + agent2.run.assert_called_once_with( + "What is the capital of France?" + ) + agent3.run.assert_called_once_with( + "What is the capital of France?" + ) # Assert conversation.add method was called with the correct responses conversation.add.assert_any_call(agent1.agent_name, results[0]) diff --git a/tests/structs/test_message_pool.py b/tests/structs/test_message_pool.py index 4dc49a2e..cfbb4df5 100644 --- a/tests/structs/test_message_pool.py +++ b/tests/structs/test_message_pool.py @@ -8,7 +8,9 @@ def test_message_pool_initialization(): agent2 = Agent(llm=OpenAIChat(), agent_name="agent1") moderator = Agent(llm=OpenAIChat(), agent_name="agent1") agents = [agent1, agent2] - message_pool = MessagePool(agents=agents, moderator=moderator, turns=5) + message_pool = MessagePool( + agents=agents, moderator=moderator, turns=5 + ) assert message_pool.agent == agents assert message_pool.moderator == moderator @@ -18,21 +20,27 @@ def test_message_pool_initialization(): def test_message_pool_add(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) + message_pool = MessagePool( + agents=[agent1], moderator=agent1, turns=5 + ) message_pool.add(agent=agent1, content="Hello, world!", turn=1) - assert message_pool.messages == [{ - "agent": agent1, - "content": "Hello, world!", - "turn": 1, - "visible_to": "all", - "logged": True, - }] + assert message_pool.messages == [ + { + "agent": agent1, + "content": "Hello, world!", + "turn": 1, + "visible_to": "all", + "logged": True, + } + ] def test_message_pool_reset(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) + message_pool = MessagePool( + agents=[agent1], moderator=agent1, turns=5 + ) message_pool.add(agent=agent1, content="Hello, world!", turn=1) message_pool.reset() @@ -41,7 +49,9 @@ def test_message_pool_reset(): def test_message_pool_last_turn(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) + message_pool = MessagePool( + agents=[agent1], moderator=agent1, turns=5 + ) message_pool.add(agent=agent1, content="Hello, world!", turn=1) assert message_pool.last_turn() == 1 @@ -49,7 +59,9 @@ def test_message_pool_last_turn(): def test_message_pool_last_message(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) + message_pool = MessagePool( + agents=[agent1], moderator=agent1, turns=5 + ) message_pool.add(agent=agent1, content="Hello, world!", turn=1) assert message_pool.last_message == { @@ -63,24 +75,28 @@ def test_message_pool_last_message(): def test_message_pool_get_all_messages(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") - message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5) + message_pool = MessagePool( + agents=[agent1], moderator=agent1, turns=5 + ) message_pool.add(agent=agent1, content="Hello, world!", turn=1) - assert message_pool.get_all_messages() == [{ - "agent": agent1, - "content": "Hello, world!", - "turn": 1, - "visible_to": "all", - "logged": True, - }] + assert message_pool.get_all_messages() == [ + { + "agent": agent1, + "content": "Hello, world!", + "turn": 1, + "visible_to": "all", + "logged": True, + } + ] def test_message_pool_get_visible_messages(): agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent2 = Agent(agent_name="agent2") - message_pool = MessagePool(agents=[agent1, agent2], - moderator=agent1, - turns=5) + message_pool = MessagePool( + agents=[agent1, agent2], moderator=agent1, turns=5 + ) message_pool.add( agent=agent1, content="Hello, agent2!", @@ -88,10 +104,14 @@ def test_message_pool_get_visible_messages(): visible_to=[agent2.agent_name], ) - assert message_pool.get_visible_messages(agent=agent2, turn=2) == [{ - "agent": agent1, - "content": "Hello, agent2!", - "turn": 1, - "visible_to": [agent2.agent_name], - "logged": True, - }] + assert message_pool.get_visible_messages( + agent=agent2, turn=2 + ) == [ + { + "agent": agent1, + "content": "Hello, agent2!", + "turn": 1, + "visible_to": [agent2.agent_name], + "logged": True, + } + ] diff --git a/tests/structs/test_model_parallizer.py b/tests/structs/test_model_parallizer.py index 54abe031..a0840608 100644 --- a/tests/structs/test_model_parallizer.py +++ b/tests/structs/test_model_parallizer.py @@ -11,9 +11,7 @@ from swarms.structs.model_parallizer import ModelParallelizer # Initialize the models custom_config = { "quantize": True, - "quantization_config": { - "load_in_4bit": True - }, + "quantization_config": {"load_in_4bit": True}, "verbose": True, } huggingface_llm = HuggingfaceLLM( @@ -26,12 +24,14 @@ zeroscope_ttv = ZeroscopeTTV() def test_init(): - mp = ModelParallelizer([ - huggingface_llm, - mixtral, - gpt4_vision_api, - zeroscope_ttv, - ]) + mp = ModelParallelizer( + [ + huggingface_llm, + mixtral, + gpt4_vision_api, + zeroscope_ttv, + ] + ) assert isinstance(mp, ModelParallelizer) @@ -39,20 +39,24 @@ def test_run(): mp = ModelParallelizer([huggingface_llm]) result = mp.run( "Create a list of known biggest risks of structural collapse" - " with references") + " with references" + ) assert isinstance(result, str) def test_run_all(): - mp = ModelParallelizer([ - huggingface_llm, - mixtral, - gpt4_vision_api, - zeroscope_ttv, - ]) + mp = ModelParallelizer( + [ + huggingface_llm, + mixtral, + gpt4_vision_api, + zeroscope_ttv, + ] + ) result = mp.run_all( "Create a list of known biggest risks of structural collapse" - " with references") + " with references" + ) assert isinstance(result, list) assert len(result) == 5 @@ -71,8 +75,10 @@ def test_remove_llm(): def test_save_responses_to_file(tmp_path): mp = ModelParallelizer([huggingface_llm]) - mp.run("Create a list of known biggest risks of structural collapse" - " with references") + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) file = tmp_path / "responses.txt" mp.save_responses_to_file(file) assert file.read_text() != "" @@ -80,8 +86,10 @@ def test_save_responses_to_file(tmp_path): def test_get_task_history(): mp = ModelParallelizer([huggingface_llm]) - mp.run("Create a list of known biggest risks of structural collapse" - " with references") + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) assert mp.get_task_history() == [ "Create a list of known biggest risks of structural collapse" " with references" @@ -90,8 +98,10 @@ def test_get_task_history(): def test_summary(capsys): mp = ModelParallelizer([huggingface_llm]) - mp.run("Create a list of known biggest risks of structural collapse" - " with references") + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) mp.summary() captured = capsys.readouterr() assert "Tasks History:" in captured.out @@ -113,7 +123,8 @@ def test_concurrent_run(): mp = ModelParallelizer([huggingface_llm, mixtral]) result = mp.concurrent_run( "Create a list of known biggest risks of structural collapse" - " with references") + " with references" + ) assert isinstance(result, list) assert len(result) == 2 diff --git a/tests/structs/test_multi_agent_collab.py b/tests/structs/test_multi_agent_collab.py index df710f29..555771e7 100644 --- a/tests/structs/test_multi_agent_collab.py +++ b/tests/structs/test_multi_agent_collab.py @@ -73,8 +73,12 @@ def test_run(collaboration): def test_format_results(collaboration): - collaboration.results = [{"agent": "Agent1", "response": "Response1"}] - formatted_results = collaboration.format_results(collaboration.results) + collaboration.results = [ + {"agent": "Agent1", "response": "Response1"} + ] + formatted_results = collaboration.format_results( + collaboration.results + ) assert "Agent1 responded: Response1" in formatted_results @@ -108,10 +112,7 @@ def test_repr(collaboration): def test_load(collaboration): state = { "step": 5, - "results": [{ - "agent": "Agent1", - "response": "Response1" - }], + "results": [{"agent": "Agent1", "response": "Response1"}], } with open(collaboration.saved_file_path_name, "w") as file: json.dump(state, file) diff --git a/tests/structs/test_nonlinear_workflow.py b/tests/structs/test_nonlinear_workflow.py index e45e86cc..2544a7e4 100644 --- a/tests/structs/test_nonlinear_workflow.py +++ b/tests/structs/test_nonlinear_workflow.py @@ -5,7 +5,6 @@ from swarms.structs import NonlinearWorkflow, Task class TestNonlinearWorkflow: - def test_add_task(self): llm = OpenAIChat(openai_api_key="") task = Task(llm, "What's the weather in miami") @@ -34,7 +33,9 @@ class TestNonlinearWorkflow: workflow = NonlinearWorkflow() workflow.add(task1, task2.name) workflow.add(task2, task1.name) - with pytest.raises(Exception, match="Circular dependency detected"): + with pytest.raises( + Exception, match="Circular dependency detected" + ): workflow.run() def test_run_with_stopping_token(self): diff --git a/tests/structs/test_recursive_workflow.py b/tests/structs/test_recursive_workflow.py index 618f955a..5b24f921 100644 --- a/tests/structs/test_recursive_workflow.py +++ b/tests/structs/test_recursive_workflow.py @@ -53,7 +53,9 @@ def test_run_stop_token_not_in_result(): try: workflow.run() except RecursionError: - pytest.fail("RecursiveWorkflow.run caused a RecursionError") + pytest.fail( + "RecursiveWorkflow.run caused a RecursionError" + ) assert agent.execute.call_count == max_iterations diff --git a/tests/structs/test_sequential_workflow.py b/tests/structs/test_sequential_workflow.py index 1b17b305..0d12991a 100644 --- a/tests/structs/test_sequential_workflow.py +++ b/tests/structs/test_sequential_workflow.py @@ -17,7 +17,6 @@ os.environ["OPENAI_API_KEY"] = "mocked_api_key" # Mock OpenAIChat class for testing class MockOpenAIChat: - def __init__(self, *args, **kwargs): pass @@ -27,7 +26,6 @@ class MockOpenAIChat: # Mock Agent class for testing class MockAgent: - def __init__(self, *args, **kwargs): pass @@ -37,7 +35,6 @@ class MockAgent: # Mock SequentialWorkflow class for testing class MockSequentialWorkflow: - def __init__(self, *args, **kwargs): pass @@ -72,7 +69,10 @@ def test_sequential_workflow_initialization(): assert len(workflow.tasks) == 0 assert workflow.max_loops == 1 assert workflow.autosave is False - assert (workflow.saved_state_filepath == "sequential_workflow_state.json") + assert ( + workflow.saved_state_filepath + == "sequential_workflow_state.json" + ) assert workflow.restore_state_filepath is None assert workflow.dashboard is False @@ -177,7 +177,6 @@ def test_sequential_workflow_workflow_dashboard(capfd): # Mock Agent class for async testing class MockAsyncAgent: - def __init__(self, *args, **kwargs): pass diff --git a/tests/structs/test_swarmnetwork.py b/tests/structs/test_swarmnetwork.py index 6b76cd04..9dc6d903 100644 --- a/tests/structs/test_swarmnetwork.py +++ b/tests/structs/test_swarmnetwork.py @@ -20,7 +20,9 @@ def test_swarm_network_init(swarm_network): @patch("swarms.structs.swarm_net.SwarmNetwork.logger") def test_run(mock_logger, swarm_network): swarm_network.run() - assert (mock_logger.info.call_count == 10) # 2 log messages per agent + assert ( + mock_logger.info.call_count == 10 + ) # 2 log messages per agent def test_run_with_mocked_agents(mocker, swarm_network): diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py index 20773ea7..de0352af 100644 --- a/tests/structs/test_task.py +++ b/tests/structs/test_task.py @@ -7,7 +7,8 @@ from dotenv import load_dotenv from swarms.models.gpt4_vision_api import GPT4VisionAPI from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( - MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,) + MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, +) from swarms.structs.agent import Agent from swarms.structs.task import Task @@ -20,11 +21,13 @@ def llm(): def test_agent_run_task(llm): - task = ("Analyze this image of an assembly line and identify any" - " issues such as misaligned parts, defects, or deviations" - " from the standard assembly process. IF there is anything" - " unsafe in the image, explain why it is unsafe and how it" - " could be improved.") + task = ( + "Analyze this image of an assembly line and identify any" + " issues such as misaligned parts, defects, or deviations" + " from the standard assembly process. IF there is anything" + " unsafe in the image, explain why it is unsafe and how it" + " could be improved." + ) img = "assembly_line.jpg" agent = Agent( @@ -46,7 +49,9 @@ def test_agent_run_task(llm): @pytest.fixture def task(): agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)] - return Task(id="Task_1", task="Task_Name", agents=agents, dependencies=[]) + return Task( + id="Task_1", task="Task_Name", agents=agents, dependencies=[] + ) # Basic tests @@ -184,7 +189,9 @@ def test_task_execute_with_condition(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" condition = mocker.Mock(return_value=True) - task = Task(description="Test task", agent=mock_agent, condition=condition) + task = Task( + description="Test task", agent=mock_agent, condition=condition + ) task.execute() assert task.result == "result" assert task.history == ["result"] @@ -194,7 +201,9 @@ def test_task_execute_with_condition_false(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" condition = mocker.Mock(return_value=False) - task = Task(description="Test task", agent=mock_agent, condition=condition) + task = Task( + description="Test task", agent=mock_agent, condition=condition + ) task.execute() assert task.result is None assert task.history == [] @@ -204,7 +213,9 @@ def test_task_execute_with_action(mocker): mock_agent = mocker.Mock(spec=Agent) mock_agent.run.return_value = "result" action = mocker.Mock() - task = Task(description="Test task", agent=mock_agent, action=action) + task = Task( + description="Test task", agent=mock_agent, action=action + ) task.execute() assert task.result == "result" assert task.history == ["result"] @@ -232,9 +243,11 @@ def test_task_handle_scheduled_task_future(mocker): agent=mock_agent, schedule_time=datetime.now() + timedelta(days=1), ) - with mocker.patch.object(task.scheduler, - "enter") as mock_enter, mocker.patch.object( - task.scheduler, "run") as mock_run: + with mocker.patch.object( + task.scheduler, "enter" + ) as mock_enter, mocker.patch.object( + task.scheduler, "run" + ) as mock_run: task.handle_scheduled_task() mock_enter.assert_called_once() mock_run.assert_called_once() diff --git a/tests/structs/test_taskqueuebase.py b/tests/structs/test_taskqueuebase.py index e6648881..512f72ae 100644 --- a/tests/structs/test_taskqueuebase.py +++ b/tests/structs/test_taskqueuebase.py @@ -21,9 +21,7 @@ def agent(): @pytest.fixture() def concrete_task_queue(): - class ConcreteTaskQueue(TaskQueueBase): - def add_task(self, task): pass # Here you would add concrete implementation of add_task @@ -53,8 +51,9 @@ def test_add_task_failure(concrete_task_queue, task): # Assuming the task is somehow invalid # Note: Concrete implementation requires logic defining what an invalid task is concrete_task_queue.add_task(task) - assert (concrete_task_queue.add_task(task) - is False) # Adding the same task again + assert ( + concrete_task_queue.add_task(task) is False + ) # Adding the same task again @pytest.mark.parametrize("invalid_task", [None, "", {}, []]) diff --git a/tests/structs/test_team.py b/tests/structs/test_team.py index 92778134..44d64e18 100644 --- a/tests/structs/test_team.py +++ b/tests/structs/test_team.py @@ -7,7 +7,6 @@ from swarms.structs.team import Team class TestTeam(unittest.TestCase): - def setUp(self): self.agent = Agent( llm=OpenAIChat(openai_api_key=""), @@ -31,17 +30,16 @@ class TestTeam(unittest.TestCase): with self.assertRaises(ValueError): self.team.check_config( - {"config": json.dumps({ - "agents": [], - "tasks": [] - })}) + {"config": json.dumps({"agents": [], "tasks": []})} + ) def test_run(self): self.assertEqual(self.team.run(), self.task.execute()) def test_sequential_loop(self): - self.assertEqual(self.team._Team__sequential_loop(), - self.task.execute()) + self.assertEqual( + self.team._Team__sequential_loop(), self.task.execute() + ) def test_log(self): self.assertIsNone(self.team._Team__log("Test message")) diff --git a/tests/structs/test_tests_graph_workflow.py b/tests/structs/test_tests_graph_workflow.py index b35ff6f3..cb5b17a7 100644 --- a/tests/structs/test_tests_graph_workflow.py +++ b/tests/structs/test_tests_graph_workflow.py @@ -27,7 +27,9 @@ def test_set_entry_point(graph_workflow): def test_set_entry_point_nonexistent_node(graph_workflow): - with pytest.raises(ValueError, match="Node does not exist in graph"): + with pytest.raises( + ValueError, match="Node does not exist in graph" + ): graph_workflow.set_entry_point("nonexistent") @@ -40,23 +42,29 @@ def test_add_edge(graph_workflow): def test_add_edge_nonexistent_node(graph_workflow): graph_workflow.add("node1", "value1") - with pytest.raises(ValueError, match="Node does not exist in graph"): + with pytest.raises( + ValueError, match="Node does not exist in graph" + ): graph_workflow.add_edge("node1", "nonexistent") def test_add_conditional_edges(graph_workflow): graph_workflow.add("node1", "value1") graph_workflow.add("node2", "value2") - graph_workflow.add_conditional_edges("node1", "condition1", - {"condition_value1": "node2"}) + graph_workflow.add_conditional_edges( + "node1", "condition1", {"condition_value1": "node2"} + ) assert "node2" in graph_workflow.graph["node1"]["edges"] def test_add_conditional_edges_nonexistent_node(graph_workflow): graph_workflow.add("node1", "value1") - with pytest.raises(ValueError, match="Node does not exist in graph"): + with pytest.raises( + ValueError, match="Node does not exist in graph" + ): graph_workflow.add_conditional_edges( - "node1", "condition1", {"condition_value1": "nonexistent"}) + "node1", "condition1", {"condition_value1": "nonexistent"} + ) def test_run(graph_workflow): diff --git a/tests/telemetry/test_posthog_utils.py b/tests/telemetry/test_posthog_utils.py index dd468d1d..0364cb3a 100644 --- a/tests/telemetry/test_posthog_utils.py +++ b/tests/telemetry/test_posthog_utils.py @@ -35,8 +35,9 @@ def test_log_activity_posthog(mock_posthog, mock_env): test_function() # Check if the Posthog capture method was called with the expected arguments - mock_posthog.capture.assert_called_once_with("test_user_id", event_name, - event_properties) + mock_posthog.capture.assert_called_once_with( + "test_user_id", event_name, event_properties + ) # Test a scenario where environment variables are not set diff --git a/tests/telemetry/test_user_utils.py b/tests/telemetry/test_user_utils.py index 5fc300e2..c7b5962c 100644 --- a/tests/telemetry/test_user_utils.py +++ b/tests/telemetry/test_user_utils.py @@ -46,7 +46,9 @@ def test_generate_unique_identifier(): # Generate unique identifiers and ensure they are valid UUID strings unique_id = generate_unique_identifier() assert isinstance(unique_id, str) - assert uuid.UUID(unique_id, version=5, namespace=uuid.NAMESPACE_DNS) + assert uuid.UUID( + unique_id, version=5, namespace=uuid.NAMESPACE_DNS + ) def test_generate_user_id_edge_case(): @@ -71,7 +73,9 @@ def test_get_system_info_edge_case(): # Test get_system_info for consistency system_info1 = get_system_info() system_info2 = get_system_info() - assert (system_info1 == system_info2) # Ensure system info remains the same + assert ( + system_info1 == system_info2 + ) # Ensure system info remains the same def test_generate_unique_identifier_edge_case(): diff --git a/tests/test_upload_tests_to_issues.py b/tests/test_upload_tests_to_issues.py index c916383e..0857c58a 100644 --- a/tests/test_upload_tests_to_issues.py +++ b/tests/test_upload_tests_to_issues.py @@ -20,7 +20,9 @@ headers = { def run_pytest(): - result = subprocess.run(["pytest"], capture_output=True, text=True) + result = subprocess.run( + ["pytest"], capture_output=True, text=True + ) return result.stdout + result.stderr @@ -54,7 +56,9 @@ def main(): errors = parse_pytest_output(pytest_output) for error in errors: - issue_response = create_github_issue(error["title"], error["body"]) + issue_response = create_github_issue( + error["title"], error["body"] + ) print(f"Issue created: {issue_response.get('html_url')}") diff --git a/tests/tokenizers/test_anthropictokenizer.py b/tests/tokenizers/test_anthropictokenizer.py index a0664bde..14b2fd86 100644 --- a/tests/tokenizers/test_anthropictokenizer.py +++ b/tests/tokenizers/test_anthropictokenizer.py @@ -16,8 +16,9 @@ def test_default_max_tokens(): assert tokenizer.default_max_tokens() == 100000 -@pytest.mark.parametrize("model,tokens", [("claude-2.1", 200000), - ("claude", 100000)]) +@pytest.mark.parametrize( + "model,tokens", [("claude-2.1", 200000), ("claude", 100000)] +) def test_default_max_tokens_models(model, tokens): tokenizer = AnthropicTokenizer(model=model) assert tokenizer.default_max_tokens() == tokens diff --git a/tests/tokenizers/test_basetokenizer.py b/tests/tokenizers/test_basetokenizer.py index c41a0545..3956d2de 100644 --- a/tests/tokenizers/test_basetokenizer.py +++ b/tests/tokenizers/test_basetokenizer.py @@ -18,7 +18,9 @@ def test_post_init(base_tokenizer): # 3. Tests for count_tokens_left with different inputs. -def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch): +def test_count_tokens_left_with_positive_diff( + base_tokenizer, monkeypatch +): # Mocking count_tokens to return a specific value monkeypatch.setattr( "swarms.tokenizers.BaseTokenizer.count_tokens", @@ -27,7 +29,9 @@ def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch): assert base_tokenizer.count_tokens_left("some text") == 50 -def test_count_tokens_left_with_zero_diff(base_tokenizer, monkeypatch): +def test_count_tokens_left_with_zero_diff( + base_tokenizer, monkeypatch +): monkeypatch.setattr( "swarms.tokenizers.BaseTokenizer.count_tokens", lambda x, y: 100, diff --git a/tests/tokenizers/test_huggingfacetokenizer.py b/tests/tokenizers/test_huggingfacetokenizer.py index dad69c1f..1eedb6e5 100644 --- a/tests/tokenizers/test_huggingfacetokenizer.py +++ b/tests/tokenizers/test_huggingfacetokenizer.py @@ -51,10 +51,18 @@ def test_prefix_space_tokens(hftokenizer): # testing _maybe_add_prefix_space method def test__maybe_add_prefix_space(hftokenizer): - assert (hftokenizer._maybe_add_prefix_space( - [101, 2003, 2010, 2050, 2001, 2339], " is why") == " is why") - assert (hftokenizer._maybe_add_prefix_space([2003, 2010, 2050, 2001, 2339], - "is why") == " is why") + assert ( + hftokenizer._maybe_add_prefix_space( + [101, 2003, 2010, 2050, 2001, 2339], " is why" + ) + == " is why" + ) + assert ( + hftokenizer._maybe_add_prefix_space( + [2003, 2010, 2050, 2001, 2339], "is why" + ) + == " is why" + ) # continuing tests for other methods... diff --git a/tests/tokenizers/test_openaitokenizer.py b/tests/tokenizers/test_openaitokenizer.py index 8d987164..3c24748d 100644 --- a/tests/tokenizers/test_openaitokenizer.py +++ b/tests/tokenizers/test_openaitokenizer.py @@ -18,21 +18,31 @@ def test_default_max_tokens(openai_tokenizer): assert openai_tokenizer.default_max_tokens() == 4096 -@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3), - (["Hello"], 4)]) +@pytest.mark.parametrize( + "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)] +) def test_count_tokens_single(openai_tokenizer, text, expected_output): - assert (openai_tokenizer.count_tokens(text, "gpt-3") == expected_output) + assert ( + openai_tokenizer.count_tokens(text, "gpt-3") + == expected_output + ) @pytest.mark.parametrize( "texts, expected_output", [(["Hello, world!", "This is a test"], 6), (["Hello"], 4)], ) -def test_count_tokens_multiple(openai_tokenizer, texts, expected_output): - assert (openai_tokenizer.count_tokens(texts, "gpt-3") == expected_output) +def test_count_tokens_multiple( + openai_tokenizer, texts, expected_output +): + assert ( + openai_tokenizer.count_tokens(texts, "gpt-3") + == expected_output + ) -@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3), - (["Hello"], 4)]) +@pytest.mark.parametrize( + "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)] +) def test_len(openai_tokenizer, text, expected_output): assert openai_tokenizer.len(text, "gpt-3") == expected_output diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py index ecd85097..b868f0a1 100644 --- a/tests/tokenizers/test_tokenizer.py +++ b/tests/tokenizers/test_tokenizer.py @@ -7,7 +7,9 @@ from swarms.tokenizers.r_tokenizers import Tokenizer def test_initializer_existing_model_file(): with patch("os.path.exists", return_value=True): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") mock_model.assert_called_with("tokenizers/my_model.model") assert tokenizer.model == mock_model.return_value @@ -15,43 +17,66 @@ def test_initializer_existing_model_file(): def test_initializer_model_folder(): with patch("os.path.exists", side_effect=[False, True]): - with patch("swarms.tokenizers.HuggingFaceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.HuggingFaceTokenizer" + ) as mock_model: tokenizer = Tokenizer("my_model_directory") mock_model.assert_called_with("my_model_directory") assert tokenizer.model == mock_model.return_value def test_vocab_size(): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert (tokenizer.vocab_size == mock_model.return_value.vocab_size) + assert ( + tokenizer.vocab_size == mock_model.return_value.vocab_size + ) def test_bos_token_id(): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert (tokenizer.bos_token_id == mock_model.return_value.bos_token_id) + assert ( + tokenizer.bos_token_id + == mock_model.return_value.bos_token_id + ) def test_encode(): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert (tokenizer.encode("hello") == - mock_model.return_value.encode.return_value) + assert ( + tokenizer.encode("hello") + == mock_model.return_value.encode.return_value + ) def test_decode(): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") - assert (tokenizer.decode( - [1, 2, 3]) == mock_model.return_value.decode.return_value) + assert ( + tokenizer.decode([1, 2, 3]) + == mock_model.return_value.decode.return_value + ) def test_call(): - with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: + with patch( + "swarms.tokenizers.SentencePieceTokenizer" + ) as mock_model: tokenizer = Tokenizer("tokenizers/my_model.model") assert ( - tokenizer("hello") == mock_model.return_value.__call__.return_value) + tokenizer("hello") + == mock_model.return_value.__call__.return_value + ) # More tests can be added here diff --git a/tests/tools/test_tools_base.py b/tests/tools/test_tools_base.py index dea29d9c..9060f53f 100644 --- a/tests/tools/test_tools_base.py +++ b/tests/tools/test_tools_base.py @@ -65,20 +65,23 @@ def test_structured_tool_invoke(): def test_tool_creation(): - tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + tool = Tool( + name="test_tool", func=lambda x: x, description="Test tool" + ) assert tool.name == "test_tool" assert tool.func is not None assert tool.description == "Test tool" def test_tool_ainvoke(): - tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + tool = Tool( + name="test_tool", func=lambda x: x, description="Test tool" + ) result = tool.ainvoke("input_data") assert result == "input_data" def test_tool_ainvoke_with_coroutine(): - async def async_function(input_data): return input_data @@ -92,7 +95,6 @@ def test_tool_ainvoke_with_coroutine(): def test_tool_args(): - def sample_function(input_data): return input_data @@ -108,7 +110,6 @@ def test_tool_args(): def test_structured_tool_creation(): - class SampleArgsSchema: pass @@ -125,7 +126,6 @@ def test_structured_tool_creation(): def test_structured_tool_ainvoke(): - class SampleArgsSchema: pass @@ -140,7 +140,6 @@ def test_structured_tool_ainvoke(): def test_structured_tool_ainvoke_with_coroutine(): - class SampleArgsSchema: pass @@ -158,7 +157,6 @@ def test_structured_tool_ainvoke_with_coroutine(): def test_structured_tool_args(): - class SampleArgsSchema: pass @@ -184,13 +182,14 @@ def test_tool_ainvoke_exception(): def test_tool_ainvoke_with_coroutine_exception(): - tool = Tool(name="test_tool", coroutine=None, description="Test tool") + tool = Tool( + name="test_tool", coroutine=None, description="Test tool" + ) with pytest.raises(NotImplementedError): tool.ainvoke("input_data") def test_structured_tool_ainvoke_exception(): - class SampleArgsSchema: pass @@ -205,7 +204,6 @@ def test_structured_tool_ainvoke_exception(): def test_structured_tool_ainvoke_with_coroutine_exception(): - class SampleArgsSchema: pass @@ -227,7 +225,6 @@ def test_tool_description_not_provided(): def test_tool_invoke_with_callbacks(): - def sample_function(input_data, callbacks=None): if callbacks: callbacks.on_start() @@ -243,7 +240,6 @@ def test_tool_invoke_with_callbacks(): def test_tool_invoke_with_new_argument(): - def sample_function(input_data, callbacks=None): return input_data @@ -253,7 +249,6 @@ def test_tool_invoke_with_new_argument(): def test_tool_ainvoke_with_new_argument(): - async def async_function(input_data, callbacks=None): return input_data @@ -263,7 +258,6 @@ def test_tool_ainvoke_with_new_argument(): def test_tool_description_from_docstring(): - def sample_function(input_data): """Sample function docstring""" return input_data @@ -273,7 +267,6 @@ def test_tool_description_from_docstring(): def test_tool_ainvoke_with_exceptions(): - async def async_function(input_data): raise ValueError("Test exception") @@ -286,7 +279,6 @@ def test_tool_ainvoke_with_exceptions(): def test_structured_tool_infer_schema_false(): - def sample_function(input_data): return input_data @@ -300,7 +292,6 @@ def test_structured_tool_infer_schema_false(): def test_structured_tool_ainvoke_with_callbacks(): - class SampleArgsSchema: pass @@ -316,14 +307,15 @@ def test_structured_tool_ainvoke_with_callbacks(): args_schema=SampleArgsSchema, ) callbacks = MagicMock() - result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks) + result = tool.ainvoke( + {"tool_input": "input_data"}, callbacks=callbacks + ) assert result == "input_data" callbacks.on_start.assert_called_once() callbacks.on_finish.assert_called_once() def test_structured_tool_description_not_provided(): - class SampleArgsSchema: pass @@ -338,7 +330,6 @@ def test_structured_tool_description_not_provided(): def test_structured_tool_args_schema(): - class SampleArgsSchema: pass @@ -354,7 +345,6 @@ def test_structured_tool_args_schema(): def test_structured_tool_args_schema_inference(): - def sample_function(input_data): return input_data @@ -368,7 +358,6 @@ def test_structured_tool_args_schema_inference(): def test_structured_tool_ainvoke_with_new_argument(): - class SampleArgsSchema: pass @@ -380,12 +369,13 @@ def test_structured_tool_ainvoke_with_new_argument(): func=sample_function, args_schema=SampleArgsSchema, ) - result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None) + result = tool.ainvoke( + {"tool_input": "input_data"}, callbacks=None + ) assert result == "input_data" def test_structured_tool_ainvoke_with_exceptions(): - class SampleArgsSchema: pass @@ -471,7 +461,9 @@ def test_tool_with_runnable(mock_runnable): def test_tool_with_invalid_argument(): # Test passing an invalid argument type with pytest.raises(ValueError): - tool(123) # Using an integer instead of a string/callable/Runnable + tool( + 123 + ) # Using an integer instead of a string/callable/Runnable def test_tool_with_multiple_arguments(mock_func): @@ -533,7 +525,9 @@ class MockSchema(BaseModel): # Test suite starts here class TestTool: # Basic Functionality Tests - def test_tool_with_valid_callable_creates_base_tool(self, mock_func): + def test_tool_with_valid_callable_creates_base_tool( + self, mock_func + ): result = tool(mock_func) assert isinstance(result, BaseTool) @@ -566,7 +560,6 @@ class TestTool: # Error Handling Tests def test_tool_raises_error_without_docstring(self): - def no_doc_func(arg: str) -> str: return arg @@ -574,14 +567,14 @@ class TestTool: tool(no_doc_func) def test_tool_raises_error_runnable_without_object_schema( - self, mock_runnable): + self, mock_runnable + ): with pytest.raises(ValueError): tool(mock_runnable) # Decorator Behavior Tests @pytest.mark.asyncio async def test_async_tool_function(self): - @tool async def async_func(arg: str) -> str: return arg @@ -604,7 +597,6 @@ class TestTool: pass def test_tool_with_different_return_types(self): - @tool def return_int(arg: str) -> int: return int(arg) @@ -623,7 +615,6 @@ class TestTool: # Test with multiple arguments def test_tool_with_multiple_args(self): - @tool def concat_strings(a: str, b: str) -> str: return a + b @@ -633,7 +624,6 @@ class TestTool: # Test handling of optional arguments def test_tool_with_optional_args(self): - @tool def greet(name: str, greeting: str = "Hello") -> str: return f"{greeting} {name}" @@ -643,7 +633,6 @@ class TestTool: # Test with variadic arguments def test_tool_with_variadic_args(self): - @tool def sum_numbers(*numbers: int) -> int: return sum(numbers) @@ -653,7 +642,6 @@ class TestTool: # Test with keyword arguments def test_tool_with_kwargs(self): - @tool def build_query(**kwargs) -> str: return "&".join(f"{k}={v}" for k, v in kwargs.items()) @@ -663,7 +651,6 @@ class TestTool: # Test with mixed types of arguments def test_tool_with_mixed_args(self): - @tool def mixed_args(a: int, b: str, *args, **kwargs) -> str: return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}" @@ -672,7 +659,6 @@ class TestTool: # Test error handling with incorrect types def test_tool_error_with_incorrect_types(self): - @tool def add_numbers(a: int, b: int) -> int: return a + b @@ -682,7 +668,6 @@ class TestTool: # Test with nested tools def test_nested_tools(self): - @tool def inner_tool(arg: str) -> str: return f"Inner {arg}" @@ -694,7 +679,6 @@ class TestTool: assert outer_tool("Test") == "Outer Inner Test" def test_tool_with_global_variable(self): - @tool def access_global(arg: str) -> str: return f"{global_var} {arg}" @@ -717,7 +701,6 @@ class TestTool: # Test with complex data structures def test_tool_with_complex_data_structures(self): - @tool def process_data(data: dict) -> list: return [data[key] for key in sorted(data.keys())] @@ -727,7 +710,6 @@ class TestTool: # Test handling exceptions within the tool function def test_tool_handling_internal_exceptions(self): - @tool def function_that_raises(arg: str): if arg == "error": @@ -740,7 +722,6 @@ class TestTool: # Test with functions returning None def test_tool_with_none_return(self): - @tool def return_none(arg: str): return None @@ -754,9 +735,7 @@ class TestTool: # Test with class methods def test_tool_with_class_method(self): - class MyClass: - @tool def method(self, arg: str) -> str: return f"Method {arg}" @@ -766,15 +745,12 @@ class TestTool: # Test tool function with inheritance def test_tool_with_inheritance(self): - class Parent: - @tool def parent_method(self, arg: str) -> str: return f"Parent {arg}" class Child(Parent): - @tool def child_method(self, arg: str) -> str: return f"Child {arg}" @@ -785,9 +761,7 @@ class TestTool: # Test with decorators stacking def test_tool_with_multiple_decorators(self): - def another_decorator(func): - def wrapper(*args, **kwargs): return f"Decorated {func(*args, **kwargs)}" @@ -813,7 +787,9 @@ class TestTool: def thread_target(): results.append(threaded_function(5)) - threads = [threading.Thread(target=thread_target) for _ in range(10)] + threads = [ + threading.Thread(target=thread_target) for _ in range(10) + ] for t in threads: t.start() for t in threads: @@ -823,7 +799,6 @@ class TestTool: # Test with recursive functions def test_tool_with_recursive_function(self): - @tool def recursive_function(n: int) -> int: if n == 0: diff --git a/tests/utils/test_check_device.py b/tests/utils/test_check_device.py index 1b07ce08..503a3774 100644 --- a/tests/utils/test_check_device.py +++ b/tests/utils/test_check_device.py @@ -19,8 +19,9 @@ def test_check_device_no_cuda(monkeypatch): def test_check_device_cuda_exception(monkeypatch): # Mock torch.cuda.is_available to raise an exception - monkeypatch.setattr(torch.cuda, "is_available", - lambda: 1 / 0) # Raises ZeroDivisionError + monkeypatch.setattr( + torch.cuda, "is_available", lambda: 1 / 0 + ) # Raises ZeroDivisionError result = check_device(log_level=logging.DEBUG) assert result.type == "cpu" @@ -32,8 +33,12 @@ def test_check_device_one_cuda(monkeypatch): # Mock torch.cuda.device_count to return 1 monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0 - monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0) - monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0) + monkeypatch.setattr( + torch.cuda, "memory_allocated", lambda device: 0 + ) + monkeypatch.setattr( + torch.cuda, "memory_reserved", lambda device: 0 + ) result = check_device(log_level=logging.DEBUG) assert len(result) == 1 @@ -47,8 +52,12 @@ def test_check_device_multiple_cuda(monkeypatch): # Mock torch.cuda.device_count to return 4 monkeypatch.setattr(torch.cuda, "device_count", lambda: 4) # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0 - monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0) - monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0) + monkeypatch.setattr( + torch.cuda, "memory_allocated", lambda device: 0 + ) + monkeypatch.setattr( + torch.cuda, "memory_reserved", lambda device: 0 + ) result = check_device(log_level=logging.DEBUG) assert len(result) == 4 diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py index 6af5ea0d..99d38b2c 100644 --- a/tests/utils/test_class_args_wrapper.py +++ b/tests/utils/test_class_args_wrapper.py @@ -19,7 +19,8 @@ def test_print_class_parameters_agent(): # Replace with the expected output for Agent class expected_output = ( "Parameter: name, Type: \nParameter: age, Type:" - " ") + " " + ) assert output == expected_output @@ -32,7 +33,9 @@ def test_print_class_parameters_error(): def get_parameters(class_name: str): classes = {"Agent": Agent} if class_name in classes: - return print_class_parameters(classes[class_name], api_format=True) + return print_class_parameters( + classes[class_name], api_format=True + ) else: return {"error": "Class not found"} diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py index d17ea08d..9be83be4 100644 --- a/tests/utils/test_device.py +++ b/tests/utils/test_device.py @@ -32,14 +32,18 @@ def test_multiple_gpus_available(mocker): def test_device_properties(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5)) + mocker.patch( + "torch.cuda.get_device_capability", return_value=(7, 5) + ) mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) mocker.patch("torch.cuda.memory_allocated", return_value=200) mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) devices = check_device() assert len(devices) == 1 assert str(devices[0]) == "cuda" @@ -48,21 +52,27 @@ def test_device_properties(mocker): def test_memory_threshold(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5)) + mocker.patch( + "torch.cuda.get_device_capability", return_value=(7, 5) + ) mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) - mocker.patch("torch.cuda.memory_allocated", - return_value=900) # 90% of total memory + mocker.patch( + "torch.cuda.memory_allocated", return_value=900 + ) # 90% of total memory mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) with pytest.warns( - UserWarning, - match=r"Memory usage for device cuda exceeds threshold", + UserWarning, + match=r"Memory usage for device cuda exceeds threshold", ): devices = check_device( - memory_threshold=0.8) # Set memory threshold to 80% + memory_threshold=0.8 + ) # Set memory threshold to 80% assert len(devices) == 1 assert str(devices[0]) == "cuda" @@ -70,21 +80,27 @@ def test_memory_threshold(mocker): def test_compute_capability_threshold(mocker): mocker.patch("torch.cuda.is_available", return_value=True) mocker.patch("torch.cuda.device_count", return_value=1) - mocker.patch("torch.cuda.get_device_capability", - return_value=(3, 0)) # Compute capability 3.0 + mocker.patch( + "torch.cuda.get_device_capability", return_value=(3, 0) + ) # Compute capability 3.0 mocker.patch( "torch.cuda.get_device_properties", return_value=MagicMock(total_memory=1000), ) mocker.patch("torch.cuda.memory_allocated", return_value=200) mocker.patch("torch.cuda.memory_reserved", return_value=300) - mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80") + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) with pytest.warns( - UserWarning, - match=(r"Compute capability for device cuda is below threshold"), + UserWarning, + match=( + r"Compute capability for device cuda is below threshold" + ), ): devices = check_device( - capability_threshold=3.5) # Set compute capability threshold to 3.5 + capability_threshold=3.5 + ) # Set compute capability threshold to 3.5 assert len(devices) == 1 assert str(devices[0]) == "cuda" diff --git a/tests/utils/test_display_markdown_message.py b/tests/utils/test_display_markdown_message.py index 9a9b9327..1b7cadaa 100644 --- a/tests/utils/test_display_markdown_message.py +++ b/tests/utils/test_display_markdown_message.py @@ -14,7 +14,8 @@ def test_basic_message(): with mock.patch.object(Console, "print") as mock_print: display_markdown_message("This is a test") mock_print.assert_called_once_with( - Markdown("This is a test", style="cyan")) + Markdown("This is a test", style="cyan") + ) def test_empty_message(): @@ -30,7 +31,8 @@ def test_colors(color): with mock.patch.object(Console, "print") as mock_print: display_markdown_message("This is a test", color) mock_print.assert_called_once_with( - Markdown("This is a test", style=color)) + Markdown("This is a test", style=color) + ) def test_dash_line(): diff --git a/tests/utils/test_extract_code_from_markdown.py b/tests/utils/test_extract_code_from_markdown.py index 8bbee570..eb1a3e5d 100644 --- a/tests/utils/test_extract_code_from_markdown.py +++ b/tests/utils/test_extract_code_from_markdown.py @@ -22,8 +22,12 @@ def markdown_content_without_code(): """ -def test_extract_code_from_markdown_with_code(markdown_content_with_code,): - extracted_code = extract_code_from_markdown(markdown_content_with_code) +def test_extract_code_from_markdown_with_code( + markdown_content_with_code, +): + extracted_code = extract_code_from_markdown( + markdown_content_with_code + ) assert "def my_func():" in extracted_code assert 'print("This is my function.")' in extracted_code assert "class MyClass:" in extracted_code @@ -31,8 +35,11 @@ def test_extract_code_from_markdown_with_code(markdown_content_with_code,): def test_extract_code_from_markdown_without_code( - markdown_content_without_code,): - extracted_code = extract_code_from_markdown(markdown_content_without_code) + markdown_content_without_code, +): + extracted_code = extract_code_from_markdown( + markdown_content_without_code + ) assert extracted_code == "" diff --git a/tests/utils/test_find_image_path.py b/tests/utils/test_find_image_path.py index 15de0f54..29b1c627 100644 --- a/tests/utils/test_find_image_path.py +++ b/tests/utils/test_find_image_path.py @@ -8,8 +8,12 @@ from swarms.utils import find_image_path def test_find_image_path_no_images(): - assert (find_image_path("This is a test string without any image paths.") - is None) + assert ( + find_image_path( + "This is a test string without any image paths." + ) + is None + ) def test_find_image_path_one_image(): @@ -19,7 +23,9 @@ def test_find_image_path_one_image(): def test_find_image_path_multiple_images(): text = "This string has two image paths: img1.png, and img2.jpg." - assert (find_image_path(text) == "img2.jpg") # Assuming both images exist + assert ( + find_image_path(text) == "img2.jpg" + ) # Assuming both images exist def test_find_image_path_wrong_input(): diff --git a/tests/utils/test_limit_tokens_from_string.py b/tests/utils/test_limit_tokens_from_string.py index cc869dd6..4d68dccb 100644 --- a/tests/utils/test_limit_tokens_from_string.py +++ b/tests/utils/test_limit_tokens_from_string.py @@ -4,11 +4,14 @@ from swarms.utils import limit_tokens_from_string def test_limit_tokens_from_string(): - sentence = ("This is a test sentence. It is used for testing the number" - " of tokens.") + sentence = ( + "This is a test sentence. It is used for testing the number" + " of tokens." + ) limited = limit_tokens_from_string(sentence, limit=5) - assert (len(limited.split()) - <= 5), "The output string has more than 5 tokens." + assert ( + len(limited.split()) <= 5 + ), "The output string has more than 5 tokens." def test_limit_zero_tokens(): @@ -18,21 +21,26 @@ def test_limit_zero_tokens(): def test_negative_token_limit(): - sentence = ("This test will raise an exception when limit is negative.") + sentence = ( + "This test will raise an exception when limit is negative." + ) with pytest.raises(Exception): limit_tokens_from_string(sentence, limit=-1) -@pytest.mark.parametrize("sentence, model", - [("Some sentence", "unavailable-model")]) +@pytest.mark.parametrize( + "sentence, model", [("Some sentence", "unavailable-model")] +) def test_unknown_model(sentence, model): with pytest.raises(Exception): limit_tokens_from_string(sentence, model=model) def test_string_token_limit_exceeded(): - sentence = ("This is a long sentence with more than twenty tokens which" - " is used for testing. It checks whether the function" - " correctly limits the tokens to a specified amount.") + sentence = ( + "This is a long sentence with more than twenty tokens which" + " is used for testing. It checks whether the function" + " correctly limits the tokens to a specified amount." + ) limited = limit_tokens_from_string(sentence, limit=20) assert len(limited.split()) <= 20, "The token limit is exceeded." diff --git a/tests/utils/test_load_model_torch.py b/tests/utils/test_load_model_torch.py index 89bd6960..c2018c6a 100644 --- a/tests/utils/test_load_model_torch.py +++ b/tests/utils/test_load_model_torch.py @@ -6,7 +6,6 @@ from swarms.utils import load_model_torch class DummyModel(nn.Module): - def __init__(self): super().__init__() self.fc = nn.Linear(10, 2) @@ -26,7 +25,9 @@ def test_load_model_torch_success(tmp_path): model_loaded = load_model_torch(model_path, model=DummyModel()) # Check if loaded model has the same architecture - assert isinstance(model_loaded, DummyModel), "Loaded model type mismatch." + assert isinstance( + model_loaded, DummyModel + ), "Loaded model type mismatch." # Test case 2: Test if function raises FileNotFoundError for non-existent file @@ -65,12 +66,13 @@ def test_load_model_torch_device_handling(tmp_path): # Define a device other than default and load the model to the specified device device = torch.device("cpu") - model_loaded = load_model_torch(model_path, - model=DummyModel(), - device=device) + model_loaded = load_model_torch( + model_path, model=DummyModel(), device=device + ) - assert (model_loaded.fc.weight.device == device - ), "Model not loaded to specified device." + assert ( + model_loaded.fc.weight.device == device + ), "Model not loaded to specified device." # Test case 6: Testing for correct handling of '*args' and '**kwargs' @@ -80,14 +82,15 @@ def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path): torch.save(model.state_dict(), model_path) def mock_torch_load(*args, **kwargs): - assert ("pickle_module" - in kwargs), "Keyword arguments not passed to 'torch.load'." + assert ( + "pickle_module" in kwargs + ), "Keyword arguments not passed to 'torch.load'." # Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly monkeypatch.setattr(torch, "load", mock_torch_load) - load_model_torch(model_path, - model=DummyModel(), - pickle_module="dummy_module") + load_model_torch( + model_path, model=DummyModel(), pickle_module="dummy_module" + ) # Test case 7: Test for model loading on CPU if no GPU is available @@ -100,8 +103,9 @@ def test_load_model_torch_cpu(tmp_path): return False # Monkeypatch to simulate no GPU available - pytest.MonkeyPatch.setattr(torch.cuda, "is_available", - mock_torch_cuda_is_available) + pytest.MonkeyPatch.setattr( + torch.cuda, "is_available", mock_torch_cuda_is_available + ) model_loaded = load_model_torch(model_path, model=DummyModel()) # Ensure model is loaded on CPU diff --git a/tests/utils/test_load_models_torch.py b/tests/utils/test_load_models_torch.py index 12b7605c..3f09f411 100644 --- a/tests/utils/test_load_models_torch.py +++ b/tests/utils/test_load_models_torch.py @@ -42,13 +42,15 @@ def test_load_model_torch_model_specified(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value={"key": "value"}) load_model_torch("model_path", model=mock_model) - mock_model.load_state_dict.assert_called_once_with({"key": "value"}, - strict=True) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=True + ) def test_load_model_torch_model_specified_strict_false(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value={"key": "value"}) load_model_torch("model_path", model=mock_model, strict=False) - mock_model.load_state_dict.assert_called_once_with({"key": "value"}, - strict=False) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=False + ) diff --git a/tests/utils/test_math_eval.py b/tests/utils/test_math_eval.py index e3ddad9f..ae7ee04c 100644 --- a/tests/utils/test_math_eval.py +++ b/tests/utils/test_math_eval.py @@ -18,7 +18,6 @@ def func2_with_exception(x): def test_same_results_no_exception(caplog): - @math_eval(func1_no_exception, func2_no_exception) def test_func(x): return x @@ -29,7 +28,6 @@ def test_same_results_no_exception(caplog): def test_func1_exception(caplog): - @math_eval(func1_with_exception, func2_no_exception) def test_func(x): return x diff --git a/tests/utils/test_metrics_decorator.py b/tests/utils/test_metrics_decorator.py index 3ad34684..8c3a8af9 100644 --- a/tests/utils/test_metrics_decorator.py +++ b/tests/utils/test_metrics_decorator.py @@ -10,7 +10,6 @@ from swarms.utils import metrics_decorator # Basic successful test def test_metrics_decorator_success(): - @metrics_decorator def decorated_func(): time.sleep(0.1) @@ -31,8 +30,8 @@ def test_metrics_decorator_success(): ], ) def test_metrics_decorator_with_various_wait_times_and_return_vals( - wait_time, return_val): - + wait_time, return_val +): @metrics_decorator def decorated_func(): time.sleep(wait_time) @@ -56,17 +55,19 @@ def test_metrics_decorator_with_mocked_time(mocker): return ["tok_1", "tok_2"] metrics = decorated_func() - assert (metrics == """ + assert ( + metrics + == """ Time to First Token: 5 Generation Latency: 20 Throughput: 0.1 - """) + """ + ) mocked_time.assert_any_call() # Test to ensure that exceptions in the decorated function are propagated def test_metrics_decorator_raises_exception(): - @metrics_decorator def decorated_func(): raise ValueError("Oops!") @@ -77,7 +78,6 @@ def test_metrics_decorator_raises_exception(): # Test to ensure proper handling when decorated function returns non-list value def test_metrics_decorator_with_non_list_return_val(): - @metrics_decorator def decorated_func(): return "Hello, world!" diff --git a/tests/utils/test_pdf_to_text.py b/tests/utils/test_pdf_to_text.py index 115a58b1..257364b4 100644 --- a/tests/utils/test_pdf_to_text.py +++ b/tests/utils/test_pdf_to_text.py @@ -29,8 +29,8 @@ def test_passing_non_pdf_file(tmpdir): file = tmpdir.join("temp.txt") file.write("This is a test") with pytest.raises( - Exception, - match=r"An error occurred while reading the PDF file", + Exception, + match=r"An error occurred while reading the PDF file", ): pdf_to_text(str(file)) diff --git a/tests/utils/test_prep_torch_inference.py b/tests/utils/test_prep_torch_inference.py index 8fcf7666..6af4a9a7 100644 --- a/tests/utils/test_prep_torch_inference.py +++ b/tests/utils/test_prep_torch_inference.py @@ -9,13 +9,16 @@ from swarms.utils import prep_torch_inference def test_prep_torch_inference(): model_path = "model_path" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) model_mock = Mock() model_mock.eval = Mock() # Mocking the load_model_torch function to return our mock model. - with unittest.mock.patch("swarms.utils.load_model_torch", - return_value=model_mock) as _: + with unittest.mock.patch( + "swarms.utils.load_model_torch", return_value=model_mock + ) as _: model = prep_torch_inference(model_path, device) # Check if model was properly loaded and eval function was called diff --git a/tests/utils/test_prep_torch_model_inference.py b/tests/utils/test_prep_torch_model_inference.py index fb93cfb9..07da4e97 100644 --- a/tests/utils/test_prep_torch_model_inference.py +++ b/tests/utils/test_prep_torch_model_inference.py @@ -3,7 +3,8 @@ from unittest.mock import MagicMock import torch from swarms.utils.prep_torch_model_inference import ( - prep_torch_inference,) + prep_torch_inference, +) def test_prep_torch_inference_no_model_path(): diff --git a/tests/utils/test_print_class_parameters.py b/tests/utils/test_print_class_parameters.py index 4fde69e5..9a133ae4 100644 --- a/tests/utils/test_print_class_parameters.py +++ b/tests/utils/test_print_class_parameters.py @@ -4,30 +4,27 @@ from swarms.utils import print_class_parameters class TestObject: - def __init__(self, value1, value2: int): pass class TestObject2: - def __init__(self: "TestObject2", value1, value2: int = 5): pass def test_class_with_complex_parameters(): - class ComplexArgs: - def __init__(self, value1: list, value2: dict = {}): pass output = {"value1": "", "value2": ""} - assert (print_class_parameters(ComplexArgs, api_format=True) == output) + assert ( + print_class_parameters(ComplexArgs, api_format=True) == output + ) def test_empty_class(): - class Empty: pass @@ -36,9 +33,7 @@ def test_empty_class(): def test_class_with_no_annotations(): - class NoAnnotations: - def __init__(self, value1, value2): pass @@ -46,13 +41,14 @@ def test_class_with_no_annotations(): "value1": "", "value2": "", } - assert (print_class_parameters(NoAnnotations, api_format=True) == output) + assert ( + print_class_parameters(NoAnnotations, api_format=True) + == output + ) def test_class_with_partial_annotations(): - class PartialAnnotations: - def __init__(self, value1, value2: int): pass @@ -60,8 +56,10 @@ def test_class_with_partial_annotations(): "value1": "", "value2": "", } - assert (print_class_parameters(PartialAnnotations, - api_format=True) == output) + assert ( + print_class_parameters(PartialAnnotations, api_format=True) + == output + ) @pytest.mark.parametrize( diff --git a/tests/utils/test_subprocess_code_interpreter.py b/tests/utils/test_subprocess_code_interpreter.py index 8522a59a..3bb800f5 100644 --- a/tests/utils/test_subprocess_code_interpreter.py +++ b/tests/utils/test_subprocess_code_interpreter.py @@ -5,7 +5,8 @@ import threading import pytest from swarms.utils.code_interpreter import ( # Adjust the import according to your project structure - SubprocessCodeInterpreter,) + SubprocessCodeInterpreter, +) # Fixture for the SubprocessCodeInterpreter instance @@ -29,8 +30,9 @@ def test_start_and_terminate_process(interpreter): interpreter.start_process() assert isinstance(interpreter.process, subprocess.Popen) interpreter.terminate() - assert (interpreter.process.poll() - is not None) # Process should be terminated + assert ( + interpreter.process.poll() is not None + ) # Process should be terminated # Test preprocess_code method @@ -44,22 +46,25 @@ def test_preprocess_code(interpreter): # Test detect_active_line method def test_detect_active_line(interpreter): line = "Some line of code" - assert (interpreter.detect_active_line(line) - is None) # Adjust assertion based on implementation + assert ( + interpreter.detect_active_line(line) is None + ) # Adjust assertion based on implementation # Test detect_end_of_execution method def test_detect_end_of_execution(interpreter): line = "End of execution line" - assert (interpreter.detect_end_of_execution(line) - is None) # Adjust assertion based on implementation + assert ( + interpreter.detect_end_of_execution(line) is None + ) # Adjust assertion based on implementation # Test line_postprocessor method def test_line_postprocessor(interpreter): line = "Some output line" - assert (interpreter.line_postprocessor(line) == line - ) # Adjust assertion based on implementation + assert ( + interpreter.line_postprocessor(line) == line + ) # Adjust assertion based on implementation # Test handle_stream_output method diff --git a/tests/utils/test_try_except_wrapper.py b/tests/utils/test_try_except_wrapper.py index d815988e..26b509fb 100644 --- a/tests/utils/test_try_except_wrapper.py +++ b/tests/utils/test_try_except_wrapper.py @@ -2,44 +2,44 @@ from swarms.utils.try_except_wrapper import try_except_wrapper def test_try_except_wrapper_with_no_exception(): - @try_except_wrapper def add(x, y): return x + y result = add(1, 2) - assert (result == 3), "The function should return the sum of the arguments" + assert ( + result == 3 + ), "The function should return the sum of the arguments" def test_try_except_wrapper_with_exception(): - @try_except_wrapper def divide(x, y): return x / y result = divide(1, 0) assert ( - result - is None), "The function should return None when an exception is raised" + result is None + ), "The function should return None when an exception is raised" def test_try_except_wrapper_with_multiple_arguments(): - @try_except_wrapper def concatenate(*args): return "".join(args) result = concatenate("Hello", " ", "world") - assert (result == "Hello world" - ), "The function should concatenate the arguments" + assert ( + result == "Hello world" + ), "The function should concatenate the arguments" def test_try_except_wrapper_with_keyword_arguments(): - @try_except_wrapper def greet(name="world"): return f"Hello, {name}" result = greet(name="Alice") - assert (result == "Hello, Alice" - ), "The function should use the keyword arguments" + assert ( + result == "Hello, Alice" + ), "The function should use the keyword arguments"