run pre-commit code_quality check

Branch: pull/408/head
Author: Joshua David, 1 year ago
Parent: 36c295ed8b
Commit: 5b121a7445
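The commit title indicates that every hunk below was produced by the project's pre-commit formatting hooks rather than by hand, so the changes are mechanical re-wraps, not behaviour changes. As a rough, hedged illustration only (the 70-character line length is an assumption inferred from how aggressively the calls are wrapped, not a value read from this repository's configuration), the same wrapping can be reproduced with Black's Python API:

# Minimal sketch: reproduce the wrapping style seen in the hunks below.
# The line_length of 70 is an assumption, not taken from this repo's config.
import black

src = """\
def test_tool_agent_init():
    agent = ToolAgent(name, description, model, tokenizer, json_schema)
"""

formatted = black.format_str(src, mode=black.Mode(line_length=70))
print(formatted)
# Expected shape, matching the "+" side of the diff:
# def test_tool_agent_init():
#     agent = ToolAgent(
#         name, description, model, tokenizer, json_schema
#     )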

.gitignore (vendored), 2 changed lines

@@ -110,7 +110,7 @@ docs/_build/
 # PyBuilder
 .pybuilder/
 target/
-`
+
 # Jupyter Notebook
 .ipynb_checkpoints

@@ -33,7 +33,7 @@ CODE
 """
 # Initialize the language model
-llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000)
+llm = OpenAIChat(openai_api_key=api_key, max_tokens=4096)
 # Documentation agent

@ -3,7 +3,7 @@ from pathlib import Path
from langchain.chains import RetrievalQA from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings from langchain_community.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma from langchain_community.vectorstores import Chroma

@@ -23,15 +23,19 @@ def test_multion_agent_run(mock_multion):
     assert result == "result"
     assert status == "status"
     assert last_url == "lastUrl"
-    mock_multion.browse.assert_called_once_with({
-        "cmd": "task",
-        "url": "https://www.example.com",
-        "maxSteps": 5,
-    })
+    mock_multion.browse.assert_called_once_with(
+        {
+            "cmd": "task",
+            "url": "https://www.example.com",
+            "maxSteps": 5,
+        }
+    )


 # Additional tests for different tasks
-@pytest.mark.parametrize("task", ["task1", "task2", "task3", "task4", "task5"])
+@pytest.mark.parametrize(
+    "task", ["task1", "task2", "task3", "task4", "task5"]
+)
 @patch("swarms.agents.multion_agent.multion")
 def test_multion_agent_run_different_tasks(mock_multion, task):
     mock_response = MagicMock()
@@ -50,8 +54,6 @@ def test_multion_agent_run_different_tasks(mock_multion, task):
     assert result == "result"
     assert status == "status"
     assert last_url == "lastUrl"
-    mock_multion.browse.assert_called_once_with({
-        "cmd": task,
-        "url": "https://www.example.com",
-        "maxSteps": 5
-    })
+    mock_multion.browse.assert_called_once_with(
+        {"cmd": task, "url": "https://www.example.com", "maxSteps": 5}
+    )

@@ -11,27 +11,18 @@ def test_tool_agent_init():
     json_schema = {
         "type": "object",
         "properties": {
-            "name": {
-                "type": "string"
-            },
-            "age": {
-                "type": "number"
-            },
-            "is_student": {
-                "type": "boolean"
-            },
-            "courses": {
-                "type": "array",
-                "items": {
-                    "type": "string"
-                }
-            },
+            "name": {"type": "string"},
+            "age": {"type": "number"},
+            "is_student": {"type": "boolean"},
+            "courses": {"type": "array", "items": {"type": "string"}},
         },
     }
     name = "Test Agent"
     description = "This is a test agent"
-    agent = ToolAgent(name, description, model, tokenizer, json_schema)
+    agent = ToolAgent(
+        name, description, model, tokenizer, json_schema
+    )
     assert agent.name == name
     assert agent.description == description
@@ -47,29 +38,22 @@ def test_tool_agent_run(mock_run):
     json_schema = {
         "type": "object",
         "properties": {
-            "name": {
-                "type": "string"
-            },
-            "age": {
-                "type": "number"
-            },
-            "is_student": {
-                "type": "boolean"
-            },
-            "courses": {
-                "type": "array",
-                "items": {
-                    "type": "string"
-                }
-            },
+            "name": {"type": "string"},
+            "age": {"type": "number"},
+            "is_student": {"type": "boolean"},
+            "courses": {"type": "array", "items": {"type": "string"}},
         },
     }
     name = "Test Agent"
     description = "This is a test agent"
-    task = ("Generate a person's information based on the following"
-            " schema:")
-    agent = ToolAgent(name, description, model, tokenizer, json_schema)
+    task = (
+        "Generate a person's information based on the following"
+        " schema:"
+    )
+    agent = ToolAgent(
+        name, description, model, tokenizer, json_schema
+    )
     agent.run(task)
     mock_run.assert_called_once_with(task)
@@ -81,21 +65,10 @@ def test_tool_agent_init_with_kwargs():
     json_schema = {
         "type": "object",
         "properties": {
-            "name": {
-                "type": "string"
-            },
-            "age": {
-                "type": "number"
-            },
-            "is_student": {
-                "type": "boolean"
-            },
-            "courses": {
-                "type": "array",
-                "items": {
-                    "type": "string"
-                }
-            },
+            "name": {"type": "string"},
+            "age": {"type": "number"},
+            "is_student": {"type": "boolean"},
+            "courses": {"type": "array", "items": {"type": "string"}},
         },
     }
     name = "Test Agent"
@@ -109,8 +82,9 @@ def test_tool_agent_init_with_kwargs():
         "max_string_token_length": 20,
     }
-    agent = ToolAgent(name, description, model, tokenizer, json_schema,
-                      **kwargs)
+    agent = ToolAgent(
+        name, description, model, tokenizer, json_schema, **kwargs
+    )
     assert agent.name == name
     assert agent.description == description
@@ -121,4 +95,7 @@ def test_tool_agent_init_with_kwargs():
     assert agent.max_array_length == kwargs["max_array_length"]
     assert agent.max_number_tokens == kwargs["max_number_tokens"]
     assert agent.temperature == kwargs["temperature"]
-    assert (agent.max_string_token_length == kwargs["max_string_token_length"])
+    assert (
+        agent.max_string_token_length
+        == kwargs["max_string_token_length"]
+    )

@@ -33,8 +33,9 @@ def test_memory_limit_enforced(memory):
 # Parameterized Tests
-@pytest.mark.parametrize("scores, best_score", [([10, 5, 3], 10),
-                                                ([1, 2, 3], 3)])
+@pytest.mark.parametrize(
+    "scores, best_score", [([10, 5, 3], 10), ([1, 2, 3], 3)]
+)
 def test_get_top_n(scores, best_score, memory):
     for score in scores:
         memory.add(score, {"data": f"test{score}"})

@@ -26,7 +26,8 @@ def memory_instance(memory_file):
 def test_init(memory_file):
     memory = DictSharedMemory(file_loc=memory_file)
     assert os.path.exists(
-        memory.file_loc), "Memory file should be created if non-existent"
+        memory.file_loc
+    ), "Memory file should be created if non-existent"


 def test_add_entry(memory_instance):
@@ -43,7 +44,8 @@ def test_get_top_n(memory_instance):
     memory_instance.add(9.5, "agent123", 1, "Entry A")
     memory_instance.add(8.5, "agent124", 1, "Entry B")
     top_1 = memory_instance.get_top_n(1)
-    assert (len(top_1) == 1
-            ), "get_top_n should return the correct number of top entries"
+    assert (
+        len(top_1) == 1
+    ), "get_top_n should return the correct number of top entries"
@@ -57,13 +59,17 @@ def test_get_top_n(memory_instance):
         # add more test cases
     ],
 )
-def test_parametrized_get_top_n(memory_instance, scores, agent_ids,
-                                expected_top_score):
+def test_parametrized_get_top_n(
+    memory_instance, scores, agent_ids, expected_top_score
+):
     for score, agent_id in zip(scores, agent_ids):
-        memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}")
+        memory_instance.add(
+            score, agent_id, 1, f"Entry by {agent_id}"
+        )
     top_1 = memory_instance.get_top_n(1)
     top_score = next(iter(top_1.values()))["score"]
-    assert (top_score == expected_top_score
-            ), "get_top_n should return the entry with top score"
+    assert (
+        top_score == expected_top_score
+    ), "get_top_n should return the entry with top score"
@@ -72,7 +78,9 @@ def test_parametrized_get_top_n(memory_instance, scores, agent_ids,
 def test_add_entry_invalid_input(memory_instance):
     with pytest.raises(ValueError):
-        memory_instance.add("invalid_score", "agent123", 1, "Test Entry")
+        memory_instance.add(
+            "invalid_score", "agent123", 1, "Test Entry"
+        )


 # Mocks and monkey-patching

@@ -35,13 +35,16 @@ def qa_mock():
 # Example test cases
 def test_initialization_default_settings(vector_memory):
     assert vector_memory.chunk_size == 1000
-    assert (vector_memory.chunk_overlap == 100
-            )  # assuming default overlap of 0.1
+    assert (
+        vector_memory.chunk_overlap == 100
+    )  # assuming default overlap of 0.1
     assert vector_memory.loc.exists()


 def test_add_entry(vector_memory, embeddings_mock):
-    with patch.object(vector_memory.db, "add_texts") as add_texts_mock:
+    with patch.object(
+        vector_memory.db, "add_texts"
+    ) as add_texts_mock:
         vector_memory.add("Example text")
         add_texts_mock.assert_called()
@@ -74,8 +77,9 @@ def test_ask_question_returns_string(vector_memory, qa_mock):
         ),  # Mocked object as a placeholder
     ],
 )
-def test_search_memory_different_params(vector_memory, query, k, type,
-                                        expected):
+def test_search_memory_different_params(
+    vector_memory, query, k, type, expected
+):
     with patch.object(
         vector_memory.db,
         "max_marginal_relevance_search",
@@ -86,5 +90,7 @@ def test_search_memory_different_params(vector_memory, query, k, type,
         "similarity_search_with_score",
         return_value=expected,
     ):
-        result = vector_memory.search_memory(query, k=k, type=type)
+        result = vector_memory.search_memory(
+            query, k=k, type=type
+        )
         assert len(result) == (k if k > 0 else 0)

@@ -8,7 +8,8 @@ api_key = os.getenv("PINECONE_API_KEY") or ""
 def test_init():
     with patch("pinecone.init") as MockInit, patch(
-            "pinecone.Index") as MockIndex:
+        "pinecone.Index"
+    ) as MockIndex:
         store = PineconeDB(
             api_key=api_key,
             index_name="test_index",
@@ -70,7 +71,8 @@ def test_query():
 def test_create_index():
     with patch("pinecone.init"), patch("pinecone.Index"), patch(
-            "pinecone.create_index") as MockCreateIndex:
+        "pinecone.create_index"
+    ) as MockCreateIndex:
         store = PineconeDB(
             api_key=api_key,
             index_name="test_index",

@@ -32,7 +32,8 @@ def test_create_vector_model():
 def test_add_or_update_vector():
     with patch("sqlalchemy.create_engine"), patch(
-            "sqlalchemy.orm.Session") as MockSession:
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
         db = PostgresDB(
             connection_string=PSG_CONNECTION_STRING,
             table_name="test",
@@ -50,7 +51,8 @@ def test_add_or_update_vector():
 def test_query_vectors():
     with patch("sqlalchemy.create_engine"), patch(
-            "sqlalchemy.orm.Session") as MockSession:
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
         db = PostgresDB(
             connection_string=PSG_CONNECTION_STRING,
             table_name="test",
@@ -65,7 +67,8 @@ def test_query_vectors():
 def test_delete_vector():
     with patch("sqlalchemy.create_engine"), patch(
-            "sqlalchemy.orm.Session") as MockSession:
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
         db = PostgresDB(
             connection_string=PSG_CONNECTION_STRING,
             table_name="test",

@@ -13,7 +13,8 @@ def mock_qdrant_client():
 @pytest.fixture
 def mock_sentence_transformer():
-    with patch("sentence_transformers.SentenceTransformer"
-              ) as MockSentenceTransformer:
+    with patch(
+        "sentence_transformers.SentenceTransformer"
+    ) as MockSentenceTransformer:
         yield MockSentenceTransformer()
@@ -28,7 +29,9 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client):
     assert qdrant_client.client is not None


-def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
+def test_load_embedding_model(
+    qdrant_client, mock_sentence_transformer
+):
     qdrant_client._load_embedding_model("model_name")
     mock_sentence_transformer.assert_called_once_with("model_name")
@@ -36,7 +39,8 @@ def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
 def test_setup_collection(qdrant_client, mock_qdrant_client):
     qdrant_client._setup_collection()
     mock_qdrant_client.get_collection.assert_called_once_with(
-        qdrant_client.collection_name)
+        qdrant_client.collection_name
+    )


 def test_add_vectors(qdrant_client, mock_qdrant_client):

@@ -12,29 +12,26 @@ def test_init():
 def test_add():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
-    assert memory.short_term_memory == [{
-        "role": "user",
-        "message": "Hello, world!"
-    }]
+    assert memory.short_term_memory == [
+        {"role": "user", "message": "Hello, world!"}
+    ]


 def test_get_short_term():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
-    assert memory.get_short_term() == [{
-        "role": "user",
-        "message": "Hello, world!"
-    }]
+    assert memory.get_short_term() == [
+        {"role": "user", "message": "Hello, world!"}
+    ]


 def test_get_medium_term():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
     memory.move_to_medium_term(0)
-    assert memory.get_medium_term() == [{
-        "role": "user",
-        "message": "Hello, world!"
-    }]
+    assert memory.get_medium_term() == [
+        {"role": "user", "message": "Hello, world!"}
+    ]


 def test_clear_medium_term():
@@ -48,18 +45,19 @@ def test_clear_medium_term():
 def test_get_short_term_memory_str():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
-    assert (memory.get_short_term_memory_str() ==
-            "[{'role': 'user', 'message': 'Hello, world!'}]")
+    assert (
+        memory.get_short_term_memory_str()
+        == "[{'role': 'user', 'message': 'Hello, world!'}]"
+    )


 def test_update_short_term():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
     memory.update_short_term(0, "user", "Goodbye, world!")
-    assert memory.get_short_term() == [{
-        "role": "user",
-        "message": "Goodbye, world!"
-    }]
+    assert memory.get_short_term() == [
+        {"role": "user", "message": "Goodbye, world!"}
+    ]


 def test_clear():
@@ -73,10 +71,9 @@ def test_search_memory():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
     assert memory.search_memory("Hello") == {
-        "short_term": [(0, {
-            "role": "user",
-            "message": "Hello, world!"
-        })],
+        "short_term": [
+            (0, {"role": "user", "message": "Hello, world!"})
+        ],
         "medium_term": [],
     }
@@ -84,18 +81,19 @@ def test_search_memory():
 def test_return_shortmemory_as_str():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
-    assert (memory.return_shortmemory_as_str() ==
-            "[{'role': 'user', 'message': 'Hello, world!'}]")
+    assert (
+        memory.return_shortmemory_as_str()
+        == "[{'role': 'user', 'message': 'Hello, world!'}]"
+    )


 def test_move_to_medium_term():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
     memory.move_to_medium_term(0)
-    assert memory.get_medium_term() == [{
-        "role": "user",
-        "message": "Hello, world!"
-    }]
+    assert memory.get_medium_term() == [
+        {"role": "user", "message": "Hello, world!"}
+    ]
     assert memory.get_short_term() == []
@@ -103,8 +101,10 @@ def test_return_medium_memory_as_str():
     memory = ShortTermMemory()
     memory.add("user", "Hello, world!")
     memory.move_to_medium_term(0)
-    assert (memory.return_medium_memory_as_str() ==
-            "[{'role': 'user', 'message': 'Hello, world!'}]")
+    assert (
+        memory.return_medium_memory_as_str()
+        == "[{'role': 'user', 'message': 'Hello, world!'}]"
+    )


 def test_thread_safety():
@@ -114,7 +114,9 @@ def test_thread_safety():
         for _ in range(1000):
            memory.add("user", "Hello, world!")

-    threads = [threading.Thread(target=add_messages) for _ in range(10)]
+    threads = [
+        threading.Thread(target=add_messages) for _ in range(10)
+    ]
     for thread in threads:
         thread.start()
     for thread in threads:

@@ -8,7 +8,9 @@ from swarms.memory.sqlite import SQLiteDB
 @pytest.fixture
 def db():
     conn = sqlite3.connect(":memory:")
-    conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
+    conn.execute(
+        "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
+    )
     conn.commit()
     return SQLiteDB(":memory:")
@@ -28,7 +30,9 @@ def test_delete(db):
 def test_update(db):
     db.add("INSERT INTO test (name) VALUES (?)", ("test",))
-    db.update("UPDATE test SET name = ? WHERE name = ?", ("new", "test"))
+    db.update(
+        "UPDATE test SET name = ? WHERE name = ?", ("new", "test")
+    )
     result = db.query("SELECT * FROM test")
     assert result == [(1, "new")]
@@ -41,7 +45,9 @@ def test_query(db):
 def test_execute_query(db):
     db.add("INSERT INTO test (name) VALUES (?)", ("test",))
-    result = db.execute_query("SELECT * FROM test WHERE name = ?", ("test",))
+    result = db.execute_query(
+        "SELECT * FROM test WHERE name = ?", ("test",)
+    )
     assert result == [(1, "test")]
@@ -95,4 +101,6 @@ def test_query_with_wrong_query(db):
 def test_execute_query_with_wrong_query(db):
     with pytest.raises(sqlite3.OperationalError):
-        db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",))
+        db.execute_query(
+            "SELECT * FROM wrong WHERE name = ?", ("test",)
+        )

@@ -16,7 +16,9 @@ def weaviate_client_mock():
         grpc_port="mock_grpc_port",
         grpc_secure=False,
         auth_client_secret="mock_api_key",
-        additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"},
+        additional_headers={
+            "X-OpenAI-Api-Key": "mock_openai_api_key"
+        },
         additional_config=Mock(),
     )
@@ -34,15 +36,13 @@ def weaviate_client_mock():
 # Define tests for the WeaviateDB class
 def test_create_collection(weaviate_client_mock):
     # Test creating a collection
-    weaviate_client_mock.create_collection("test_collection", [{
-        "name": "property"
-    }])
+    weaviate_client_mock.create_collection(
+        "test_collection", [{"name": "property"}]
+    )
     weaviate_client_mock.client.collections.create.assert_called_with(
         name="test_collection",
         vectorizer_config=None,
-        properties=[{
-            "name": "property"
-        }],
+        properties=[{"name": "property"}],
     )
@@ -51,9 +51,11 @@ def test_add_object(weaviate_client_mock):
     properties = {"name": "John"}
     weaviate_client_mock.add("test_collection", properties)
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.data.insert.assert_called_with(
-        properties)
+        properties
+    )


 def test_query_objects(weaviate_client_mock):
@@ -61,20 +63,26 @@ def test_query_objects(weaviate_client_mock):
     query = "name:John"
     weaviate_client_mock.query("test_collection", query)
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.query.bm25.assert_called_with(
-        query=query, limit=10)
+        query=query, limit=10
+    )


 def test_update_object(weaviate_client_mock):
     # Test updating an object
     object_id = "12345"
     properties = {"name": "Jane"}
-    weaviate_client_mock.update("test_collection", object_id, properties)
+    weaviate_client_mock.update(
+        "test_collection", object_id, properties
+    )
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.data.update.assert_called_with(
-        object_id, properties)
+        object_id, properties
+    )


 def test_delete_object(weaviate_client_mock):
@@ -82,23 +90,25 @@ def test_delete_object(weaviate_client_mock):
     object_id = "12345"
     weaviate_client_mock.delete("test_collection", object_id)
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with(
-        object_id)
+        object_id
+    )


-def test_create_collection_with_vectorizer_config(weaviate_client_mock,):
+def test_create_collection_with_vectorizer_config(
+    weaviate_client_mock,
+):
     # Test creating a collection with vectorizer configuration
     vectorizer_config = {"config_key": "config_value"}
-    weaviate_client_mock.create_collection("test_collection", [{
-        "name": "property"
-    }], vectorizer_config)
+    weaviate_client_mock.create_collection(
+        "test_collection", [{"name": "property"}], vectorizer_config
+    )
     weaviate_client_mock.client.collections.create.assert_called_with(
         name="test_collection",
         vectorizer_config=vectorizer_config,
-        properties=[{
-            "name": "property"
-        }],
+        properties=[{"name": "property"}],
     )
@@ -108,9 +118,11 @@ def test_query_objects_with_limit(weaviate_client_mock):
     limit = 20
     weaviate_client_mock.query("test_collection", query, limit)
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.query.bm25.assert_called_with(
-        query=query, limit=limit)
+        query=query, limit=limit
+    )


 def test_query_objects_without_limit(weaviate_client_mock):
@@ -118,9 +130,11 @@ def test_query_objects_without_limit(weaviate_client_mock):
     query = "name:John"
     weaviate_client_mock.query("test_collection", query)
     weaviate_client_mock.client.collections.get.assert_called_with(
-        "test_collection")
+        "test_collection"
+    )
     weaviate_client_mock.client.collections.query.bm25.assert_called_with(
-        query=query, limit=10)
+        query=query, limit=10
+    )


 def test_create_collection_failure(weaviate_client_mock):
@@ -129,10 +143,12 @@ def test_create_collection_failure(weaviate_client_mock):
         "weaviate_client.weaviate.collections.create",
         side_effect=Exception("Create error"),
     ):
-        with pytest.raises(Exception, match="Error creating collection"):
-            weaviate_client_mock.create_collection("test_collection", [{
-                "name": "property"
-            }])
+        with pytest.raises(
+            Exception, match="Error creating collection"
+        ):
+            weaviate_client_mock.create_collection(
+                "test_collection", [{"name": "property"}]
+            )


 def test_add_object_failure(weaviate_client_mock):
@@ -166,8 +182,9 @@ def test_update_object_failure(weaviate_client_mock):
         side_effect=Exception("Update error"),
     ):
         with pytest.raises(Exception, match="Error updating object"):
-            weaviate_client_mock.update("test_collection", object_id,
-                                        properties)
+            weaviate_client_mock.update(
+                "test_collection", object_id, properties
+            )


 def test_delete_object_failure(weaviate_client_mock):

@@ -8,11 +8,12 @@ from swarms.models.anthropic import Anthropic
 # Mock the Anthropic API client for testing
 class MockAnthropicClient:
     def __init__(self, *args, **kwargs):
         pass

-    def completions_create(self, prompt, stop_sequences, stream, **kwargs):
+    def completions_create(
+        self, prompt, stop_sequences, stream, **kwargs
+    ):
         return MockAnthropicResponse()
@@ -45,7 +46,9 @@ def test_anthropic_init_default_values(anthropic_instance):
     assert anthropic_instance.streaming is False
     assert anthropic_instance.default_request_timeout == 600
     assert (
-        anthropic_instance.anthropic_api_url == "https://test.anthropic.com")
+        anthropic_instance.anthropic_api_url
+        == "https://test.anthropic.com"
+    )
     assert anthropic_instance.anthropic_api_key == "test_api_key"
@@ -76,8 +79,9 @@ def test_anthropic_default_params(anthropic_instance):
     }


-def test_anthropic_run(mock_anthropic_env, mock_requests_post,
-                       anthropic_instance):
+def test_anthropic_run(
+    mock_anthropic_env, mock_requests_post, anthropic_instance
+):
     mock_response = Mock()
     mock_response.json.return_value = {"completion": "Generated text"}
     mock_requests_post.return_value = mock_response
@@ -101,8 +105,9 @@ def test_anthropic_run(mock_anthropic_env, mock_requests_post,
     )


-def test_anthropic_call(mock_anthropic_env, mock_requests_post,
-                        anthropic_instance):
+def test_anthropic_call(
+    mock_anthropic_env, mock_requests_post, anthropic_instance
+):
     mock_response = Mock()
     mock_response.json.return_value = {"completion": "Generated text"}
     mock_requests_post.return_value = mock_response
@@ -126,8 +131,9 @@ def test_anthropic_call(mock_anthropic_env, mock_requests_post,
     )


-def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post,
-                                      anthropic_instance):
+def test_anthropic_exception_handling(
+    mock_anthropic_env, mock_requests_post, anthropic_instance
+):
     mock_response = Mock()
     mock_response.json.return_value = {"error": "An error occurred"}
     mock_requests_post.return_value = mock_response
@@ -142,7 +148,6 @@ def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post,
 class MockAnthropicResponse:
     def __init__(self):
         self.completion = "Mocked Response from Anthropic"
@@ -168,7 +173,9 @@ def test_anthropic_async_call_method(anthropic_instance):
 def test_anthropic_async_stream_method(anthropic_instance):
-    async_generator = anthropic_instance.async_stream("Translate to French.")
+    async_generator = anthropic_instance.async_stream(
+        "Translate to French."
+    )
     for token in async_generator:
         assert isinstance(token, str)
@@ -192,51 +199,63 @@ def test_anthropic_wrap_prompt(anthropic_instance):
 def test_anthropic_convert_prompt(anthropic_instance):
     prompt = "What is the meaning of life?"
     converted_prompt = anthropic_instance.convert_prompt(prompt)
-    assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT)
+    assert converted_prompt.startswith(
+        anthropic_instance.HUMAN_PROMPT
+    )
     assert converted_prompt.endswith(anthropic_instance.AI_PROMPT)


 def test_anthropic_call_with_stop(anthropic_instance):
-    response = anthropic_instance("Translate to French.",
-                                  stop=["stop1", "stop2"])
+    response = anthropic_instance(
+        "Translate to French.", stop=["stop1", "stop2"]
+    )
     assert response == "Mocked Response from Anthropic"


 def test_anthropic_stream_with_stop(anthropic_instance):
-    generator = anthropic_instance.stream("Write a story.",
-                                          stop=["stop1", "stop2"])
+    generator = anthropic_instance.stream(
+        "Write a story.", stop=["stop1", "stop2"]
+    )
     for token in generator:
         assert isinstance(token, str)


 def test_anthropic_async_call_with_stop(anthropic_instance):
-    response = anthropic_instance.async_call("Tell me a joke.",
-                                             stop=["stop1", "stop2"])
+    response = anthropic_instance.async_call(
+        "Tell me a joke.", stop=["stop1", "stop2"]
+    )
    assert response == "Mocked Response from Anthropic"


 def test_anthropic_async_stream_with_stop(anthropic_instance):
-    async_generator = anthropic_instance.async_stream("Translate to French.",
-                                                      stop=["stop1", "stop2"])
+    async_generator = anthropic_instance.async_stream(
+        "Translate to French.", stop=["stop1", "stop2"]
+    )
     for token in async_generator:
         assert isinstance(token, str)


-def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance,):
+def test_anthropic_get_num_tokens_with_count_tokens(
+    anthropic_instance,
+):
     anthropic_instance.count_tokens = Mock(return_value=10)
     text = "This is a test sentence."
     num_tokens = anthropic_instance.get_num_tokens(text)
     assert num_tokens == 10


-def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance,):
+def test_anthropic_get_num_tokens_without_count_tokens(
+    anthropic_instance,
+):
     del anthropic_instance.count_tokens
     with pytest.raises(NameError):
         text = "This is a test sentence."
         anthropic_instance.get_num_tokens(text)


-def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance,):
+def test_anthropic_wrap_prompt_without_human_ai_prompt(
+    anthropic_instance,
+):
     del anthropic_instance.HUMAN_PROMPT
     del anthropic_instance.AI_PROMPT
     prompt = "What is the meaning of life?"

@@ -48,8 +48,10 @@ def test_cell_biology_response(biogpt_instance):
 # 40. Test for a question about protein structure
 def test_protein_structure_response(biogpt_instance):
-    question = ("What's the difference between alpha helix and beta sheet"
-                " structures in proteins?")
+    question = (
+        "What's the difference between alpha helix and beta sheet"
+        " structures in proteins?"
+    )
     response = biogpt_instance(question)
     assert response
     assert isinstance(response, str)
@@ -81,7 +83,9 @@ def test_bioinformatics_response(biogpt_instance):
 # 44. Test for a neuroscience question
 def test_neuroscience_response(biogpt_instance):
-    question = ("Explain the function of synapses in the nervous system.")
+    question = (
+        "Explain the function of synapses in the nervous system."
+    )
     response = biogpt_instance(question)
     assert response
     assert isinstance(response, str)
@@ -104,11 +108,8 @@ def test_init(bio_gpt):
 def test_call(bio_gpt, monkeypatch):
     def mock_pipeline(*args, **kwargs):
         class MockGenerator:
             def __call__(self, text, **kwargs):
                 return ["Generated text"]
@@ -166,7 +167,9 @@ def test_get_config_return_type(biogpt_instance):
 # 28. Test saving model functionality by checking if files are created
 @patch.object(BioGptForCausalLM, "save_pretrained")
 @patch.object(BioGptTokenizer, "save_pretrained")
-def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
+def test_save_model(
+    mock_save_model, mock_save_tokenizer, biogpt_instance
+):
     path = "test_path"
     biogpt_instance.save_model(path)
     mock_save_model.assert_called_once_with(path)
@@ -176,7 +179,9 @@ def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
 # 29. Test loading model from path
 @patch.object(BioGptForCausalLM, "from_pretrained")
 @patch.object(BioGptTokenizer, "from_pretrained")
-def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance):
+def test_load_from_path(
+    mock_load_model, mock_load_tokenizer, biogpt_instance
+):
     path = "test_path"
     biogpt_instance.load_from_path(path)
     mock_load_model.assert_called_once_with(path)
@@ -193,7 +198,9 @@ def test_print_model_metadata(biogpt_instance):
 # 31. Test that beam_search_decoding uses the correct number of beams
 @patch.object(BioGptForCausalLM, "generate")
-def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
+def test_beam_search_decoding_num_beams(
+    mock_generate, biogpt_instance
+):
     biogpt_instance.beam_search_decoding("test_sentence", num_beams=7)
     _, kwargs = mock_generate.call_args
     assert kwargs["num_beams"] == 7
@@ -201,8 +208,12 @@ def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
 # 32. Test if beam_search_decoding handles early_stopping
 @patch.object(BioGptForCausalLM, "generate")
-def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance):
-    biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False)
+def test_beam_search_decoding_early_stopping(
+    mock_generate, biogpt_instance
+):
+    biogpt_instance.beam_search_decoding(
+        "test_sentence", early_stopping=False
+    )
     _, kwargs = mock_generate.call_args
     assert kwargs["early_stopping"] is False

@@ -42,7 +42,9 @@ def test_cohere_async_api_error_handling(cohere_instance):
     cohere_instance.model = "base"
     cohere_instance.cohere_api_key = "invalid-api-key"
     with pytest.raises(Exception):
-        cohere_instance.async_call("Error handling with invalid API key.")
+        cohere_instance.async_call(
+            "Error handling with invalid API key."
+        )


 def test_cohere_stream_api_error_handling(cohere_instance):
@@ -51,7 +53,8 @@ def test_cohere_stream_api_error_handling(cohere_instance):
     cohere_instance.cohere_api_key = "invalid-api-key"
     with pytest.raises(Exception):
         generator = cohere_instance.stream(
-            "Error handling with invalid API key.")
+            "Error handling with invalid API key."
+        )
         for token in generator:
             pass
@@ -91,26 +94,31 @@ def test_cohere_convert_prompt(cohere_instance):
 def test_cohere_call_with_stop(cohere_instance):
-    response = cohere_instance("Translate to French.", stop=["stop1", "stop2"])
+    response = cohere_instance(
+        "Translate to French.", stop=["stop1", "stop2"]
+    )
     assert response == "Mocked Response from Cohere"


 def test_cohere_stream_with_stop(cohere_instance):
-    generator = cohere_instance.stream("Write a story.",
-                                       stop=["stop1", "stop2"])
+    generator = cohere_instance.stream(
+        "Write a story.", stop=["stop1", "stop2"]
+    )
     for token in generator:
         assert isinstance(token, str)


 def test_cohere_async_call_with_stop(cohere_instance):
-    response = cohere_instance.async_call("Tell me a joke.",
-                                          stop=["stop1", "stop2"])
+    response = cohere_instance.async_call(
+        "Tell me a joke.", stop=["stop1", "stop2"]
+    )
     assert response == "Mocked Response from Cohere"


 def test_cohere_async_stream_with_stop(cohere_instance):
-    async_generator = cohere_instance.async_stream("Translate to French.",
-                                                   stop=["stop1", "stop2"])
+    async_generator = cohere_instance.async_stream(
+        "Translate to French.", stop=["stop1", "stop2"]
+    )
     for token in async_generator:
         assert isinstance(token, str)
@@ -166,8 +174,12 @@ def test_base_cohere_validate_environment_without_cohere():
 # Test cases for benchmarking generations with various models
 def test_cohere_generate_with_command_light(cohere_instance):
     cohere_instance.model = "command-light"
-    response = cohere_instance("Generate text with Command Light model.")
-    assert response.startswith("Generated text with Command Light model")
+    response = cohere_instance(
+        "Generate text with Command Light model."
+    )
+    assert response.startswith(
+        "Generated text with Command Light model"
+    )


 def test_cohere_generate_with_command(cohere_instance):
@@ -190,54 +202,74 @@ def test_cohere_generate_with_base(cohere_instance):
 def test_cohere_generate_with_embed_english_v2(cohere_instance):
     cohere_instance.model = "embed-english-v2.0"
-    response = cohere_instance("Generate embeddings with English v2.0 model.")
-    assert response.startswith("Generated embeddings with English v2.0 model")
+    response = cohere_instance(
+        "Generate embeddings with English v2.0 model."
+    )
+    assert response.startswith(
+        "Generated embeddings with English v2.0 model"
+    )


 def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
     cohere_instance.model = "embed-english-light-v2.0"
     response = cohere_instance(
-        "Generate embeddings with English Light v2.0 model.")
+        "Generate embeddings with English Light v2.0 model."
+    )
     assert response.startswith(
-        "Generated embeddings with English Light v2.0 model")
+        "Generated embeddings with English Light v2.0 model"
+    )


 def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
     cohere_instance.model = "embed-multilingual-v2.0"
     response = cohere_instance(
-        "Generate embeddings with Multilingual v2.0 model.")
+        "Generate embeddings with Multilingual v2.0 model."
+    )
     assert response.startswith(
-        "Generated embeddings with Multilingual v2.0 model")
+        "Generated embeddings with Multilingual v2.0 model"
+    )


 def test_cohere_generate_with_embed_english_v3(cohere_instance):
     cohere_instance.model = "embed-english-v3.0"
-    response = cohere_instance("Generate embeddings with English v3.0 model.")
-    assert response.startswith("Generated embeddings with English v3.0 model")
+    response = cohere_instance(
+        "Generate embeddings with English v3.0 model."
+    )
+    assert response.startswith(
+        "Generated embeddings with English v3.0 model"
+    )


 def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
     cohere_instance.model = "embed-english-light-v3.0"
     response = cohere_instance(
-        "Generate embeddings with English Light v3.0 model.")
+        "Generate embeddings with English Light v3.0 model."
+    )
     assert response.startswith(
-        "Generated embeddings with English Light v3.0 model")
+        "Generated embeddings with English Light v3.0 model"
+    )


 def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
     cohere_instance.model = "embed-multilingual-v3.0"
     response = cohere_instance(
-        "Generate embeddings with Multilingual v3.0 model.")
+        "Generate embeddings with Multilingual v3.0 model."
+    )
     assert response.startswith(
-        "Generated embeddings with Multilingual v3.0 model")
+        "Generated embeddings with Multilingual v3.0 model"
+    )


-def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance,):
+def test_cohere_generate_with_embed_multilingual_light_v3(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-light-v3.0"
     response = cohere_instance(
-        "Generate embeddings with Multilingual Light v3.0 model.")
+        "Generate embeddings with Multilingual Light v3.0 model."
+    )
     assert response.startswith(
-        "Generated embeddings with Multilingual Light v3.0 model")
+        "Generated embeddings with Multilingual Light v3.0 model"
+    )


 # Add more test cases to benchmark other models and functionalities
@@ -267,13 +299,17 @@ def test_cohere_call_with_embed_english_v3_model(cohere_instance):
     assert isinstance(response, str)


-def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance,):
+def test_cohere_call_with_embed_multilingual_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v2.0"
     response = cohere_instance("Translate to French.")
     assert isinstance(response, str)


-def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance,):
+def test_cohere_call_with_embed_multilingual_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v3.0"
     response = cohere_instance("Translate to French.")
     assert isinstance(response, str)
@@ -293,7 +329,9 @@ def test_cohere_call_with_long_prompt(cohere_instance):
 def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
     cohere_instance.max_tokens = 10
-    prompt = ("This is a test prompt that will exceed the max tokens limit.")
+    prompt = (
+        "This is a test prompt that will exceed the max tokens limit."
+    )
     with pytest.raises(ValueError):
         cohere_instance(prompt)
@@ -326,14 +364,18 @@ def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
         assert isinstance(token, str)


-def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance,):
+def test_cohere_stream_with_embed_multilingual_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v2.0"
     generator = cohere_instance.stream("Write a story.")
     for token in generator:
         assert isinstance(token, str)


-def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance,):
+def test_cohere_stream_with_embed_multilingual_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v3.0"
     generator = cohere_instance.stream("Write a story.")
     for token in generator:
@@ -352,25 +394,33 @@ def test_cohere_async_call_with_base_model(cohere_instance):
     assert isinstance(response, str)


-def test_cohere_async_call_with_embed_english_v2_model(cohere_instance,):
+def test_cohere_async_call_with_embed_english_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-english-v2.0"
     response = cohere_instance.async_call("Translate to French.")
     assert isinstance(response, str)


-def test_cohere_async_call_with_embed_english_v3_model(cohere_instance,):
+def test_cohere_async_call_with_embed_english_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-english-v3.0"
     response = cohere_instance.async_call("Translate to French.")
     assert isinstance(response, str)


-def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance,):
+def test_cohere_async_call_with_embed_multilingual_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v2.0"
     response = cohere_instance.async_call("Translate to French.")
     assert isinstance(response, str)


-def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance,):
+def test_cohere_async_call_with_embed_multilingual_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v3.0"
     response = cohere_instance.async_call("Translate to French.")
     assert isinstance(response, str)
@@ -390,28 +440,36 @@ def test_cohere_async_stream_with_base_model(cohere_instance):
         assert isinstance(token, str)


-def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance,):
+def test_cohere_async_stream_with_embed_english_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-english-v2.0"
     async_generator = cohere_instance.async_stream("Write a story.")
     for token in async_generator:
         assert isinstance(token, str)


-def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance,):
+def test_cohere_async_stream_with_embed_english_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-english-v3.0"
     async_generator = cohere_instance.async_stream("Write a story.")
     for token in async_generator:
         assert isinstance(token, str)


-def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance,):
+def test_cohere_async_stream_with_embed_multilingual_v2_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v2.0"
     async_generator = cohere_instance.async_stream("Write a story.")
     for token in async_generator:
         assert isinstance(token, str)


-def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,):
+def test_cohere_async_stream_with_embed_multilingual_v3_model(
+    cohere_instance,
+):
     cohere_instance.model = "embed-multilingual-v3.0"
     async_generator = cohere_instance.async_stream("Write a story.")
     for token in async_generator:
@@ -421,7 +479,9 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,):
 def test_cohere_representation_model_embedding(cohere_instance):
     # Test using the Representation model for text embedding
     cohere_instance.model = "embed-english-v3.0"
-    embedding = cohere_instance.embed("Generate an embedding for this text.")
+    embedding = cohere_instance.embed(
+        "Generate an embedding for this text."
+    )
     assert isinstance(embedding, list)
     assert len(embedding) > 0
@@ -435,20 +495,26 @@ def test_cohere_representation_model_classification(cohere_instance):
     assert "score" in classification


-def test_cohere_representation_model_language_detection(cohere_instance,):
+def test_cohere_representation_model_language_detection(
+    cohere_instance,
+):
     # Test using the Representation model for language detection
     cohere_instance.model = "embed-english-v3.0"
     language = cohere_instance.detect_language(
-        "Detect the language of this text.")
+        "Detect the language of this text."
+    )
     assert isinstance(language, str)


 def test_cohere_representation_model_max_tokens_limit_exceeded(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test handling max tokens limit exceeded error
     cohere_instance.model = "embed-english-v3.0"
     cohere_instance.max_tokens = 10
-    prompt = ("This is a test prompt that will exceed the max tokens limit.")
+    prompt = (
+        "This is a test prompt that will exceed the max tokens limit."
+    )
     with pytest.raises(ValueError):
         cohere_instance.embed(prompt)
@@ -456,80 +522,102 @@ def test_cohere_representation_model_max_tokens_limit_exceeded(
 # Add more production-grade test cases based on real-world scenarios


-def test_cohere_representation_model_multilingual_embedding(cohere_instance,):
+def test_cohere_representation_model_multilingual_embedding(
+    cohere_instance,
+):
     # Test using the Representation model for multilingual text embedding
     cohere_instance.model = "embed-multilingual-v3.0"
-    embedding = cohere_instance.embed("Generate multilingual embeddings.")
+    embedding = cohere_instance.embed(
+        "Generate multilingual embeddings."
+    )
     assert isinstance(embedding, list)
     assert len(embedding) > 0


 def test_cohere_representation_model_multilingual_classification(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test using the Representation model for multilingual text classification
     cohere_instance.model = "embed-multilingual-v3.0"
-    classification = cohere_instance.classify("Classify multilingual text.")
+    classification = cohere_instance.classify(
+        "Classify multilingual text."
+    )
     assert isinstance(classification, dict)
     assert "class" in classification
     assert "score" in classification


 def test_cohere_representation_model_multilingual_language_detection(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test using the Representation model for multilingual language detection
     cohere_instance.model = "embed-multilingual-v3.0"
     language = cohere_instance.detect_language(
-        "Detect the language of multilingual text.")
+        "Detect the language of multilingual text."
+    )
     assert isinstance(language, str)


 def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test handling max tokens limit exceeded error for multilingual model
     cohere_instance.model = "embed-multilingual-v3.0"
     cohere_instance.max_tokens = 10
-    prompt = ("This is a test prompt that will exceed the max tokens limit"
-              " for multilingual model.")
+    prompt = (
+        "This is a test prompt that will exceed the max tokens limit"
+        " for multilingual model."
+    )
     with pytest.raises(ValueError):
         cohere_instance.embed(prompt)


 def test_cohere_representation_model_multilingual_light_embedding(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test using the Representation model for multilingual light text embedding
     cohere_instance.model = "embed-multilingual-light-v3.0"
-    embedding = cohere_instance.embed("Generate multilingual light embeddings.")
+    embedding = cohere_instance.embed(
+        "Generate multilingual light embeddings."
+    )
     assert isinstance(embedding, list)
     assert len(embedding) > 0


 def test_cohere_representation_model_multilingual_light_classification(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test using the Representation model for multilingual light text classification
     cohere_instance.model = "embed-multilingual-light-v3.0"
     classification = cohere_instance.classify(
-        "Classify multilingual light text.")
+        "Classify multilingual light text."
+    )
     assert isinstance(classification, dict)
     assert "class" in classification
     assert "score" in classification


 def test_cohere_representation_model_multilingual_light_language_detection(
-        cohere_instance,):
+    cohere_instance,
+):
     # Test using the Representation model for multilingual light language detection
     cohere_instance.model = "embed-multilingual-light-v3.0"
     language = cohere_instance.detect_language(
-        "Detect the language of multilingual light text.")
+        "Detect the language of multilingual light text."
+    )
assert isinstance(language, str) assert isinstance(language, str)
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded( def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
cohere_instance,): cohere_instance,
):
# Test handling max tokens limit exceeded error for multilingual light model # Test handling max tokens limit exceeded error for multilingual light model
cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.model = "embed-multilingual-light-v3.0"
cohere_instance.max_tokens = 10 cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit" prompt = (
" for multilingual light model.") "This is a test prompt that will exceed the max tokens limit"
" for multilingual light model."
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cohere_instance.embed(prompt) cohere_instance.embed(prompt)
@ -537,14 +625,18 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede
def test_cohere_command_light_model(cohere_instance): def test_cohere_command_light_model(cohere_instance):
# Test using the Command Light model for text generation # Test using the Command Light model for text generation
cohere_instance.model = "command-light" cohere_instance.model = "command-light"
response = cohere_instance("Generate text using Command Light model.") response = cohere_instance(
"Generate text using Command Light model."
)
assert isinstance(response, str) assert isinstance(response, str)
def test_cohere_base_light_model(cohere_instance): def test_cohere_base_light_model(cohere_instance):
# Test using the Base Light model for text generation # Test using the Base Light model for text generation
cohere_instance.model = "base-light" cohere_instance.model = "base-light"
response = cohere_instance("Generate text using Base Light model.") response = cohere_instance(
"Generate text using Base Light model."
)
assert isinstance(response, str) assert isinstance(response, str)
@ -555,7 +647,9 @@ def test_cohere_generate_summarize_endpoint(cohere_instance):
assert isinstance(response, str) assert isinstance(response, str)
def test_cohere_representation_model_english_embedding(cohere_instance,): def test_cohere_representation_model_english_embedding(
cohere_instance,
):
# Test using the Representation model for English text embedding # Test using the Representation model for English text embedding
cohere_instance.model = "embed-english-v3.0" cohere_instance.model = "embed-english-v3.0"
embedding = cohere_instance.embed("Generate English embeddings.") embedding = cohere_instance.embed("Generate English embeddings.")
@ -563,69 +657,90 @@ def test_cohere_representation_model_english_embedding(cohere_instance,):
assert len(embedding) > 0 assert len(embedding) > 0
def test_cohere_representation_model_english_classification(cohere_instance,): def test_cohere_representation_model_english_classification(
cohere_instance,
):
# Test using the Representation model for English text classification # Test using the Representation model for English text classification
cohere_instance.model = "embed-english-v3.0" cohere_instance.model = "embed-english-v3.0"
classification = cohere_instance.classify("Classify English text.") classification = cohere_instance.classify(
"Classify English text."
)
assert isinstance(classification, dict) assert isinstance(classification, dict)
assert "class" in classification assert "class" in classification
assert "score" in classification assert "score" in classification
def test_cohere_representation_model_english_language_detection( def test_cohere_representation_model_english_language_detection(
cohere_instance,): cohere_instance,
):
# Test using the Representation model for English language detection # Test using the Representation model for English language detection
cohere_instance.model = "embed-english-v3.0" cohere_instance.model = "embed-english-v3.0"
language = cohere_instance.detect_language( language = cohere_instance.detect_language(
"Detect the language of English text.") "Detect the language of English text."
)
assert isinstance(language, str) assert isinstance(language, str)
def test_cohere_representation_model_english_max_tokens_limit_exceeded( def test_cohere_representation_model_english_max_tokens_limit_exceeded(
cohere_instance,): cohere_instance,
):
# Test handling max tokens limit exceeded error for English model # Test handling max tokens limit exceeded error for English model
cohere_instance.model = "embed-english-v3.0" cohere_instance.model = "embed-english-v3.0"
cohere_instance.max_tokens = 10 cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit" prompt = (
" for English model.") "This is a test prompt that will exceed the max tokens limit"
" for English model."
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cohere_instance.embed(prompt) cohere_instance.embed(prompt)
def test_cohere_representation_model_english_light_embedding(cohere_instance,): def test_cohere_representation_model_english_light_embedding(
cohere_instance,
):
# Test using the Representation model for English light text embedding # Test using the Representation model for English light text embedding
cohere_instance.model = "embed-english-light-v3.0" cohere_instance.model = "embed-english-light-v3.0"
embedding = cohere_instance.embed("Generate English light embeddings.") embedding = cohere_instance.embed(
"Generate English light embeddings."
)
assert isinstance(embedding, list) assert isinstance(embedding, list)
assert len(embedding) > 0 assert len(embedding) > 0
def test_cohere_representation_model_english_light_classification( def test_cohere_representation_model_english_light_classification(
cohere_instance,): cohere_instance,
):
# Test using the Representation model for English light text classification # Test using the Representation model for English light text classification
cohere_instance.model = "embed-english-light-v3.0" cohere_instance.model = "embed-english-light-v3.0"
classification = cohere_instance.classify("Classify English light text.") classification = cohere_instance.classify(
"Classify English light text."
)
assert isinstance(classification, dict) assert isinstance(classification, dict)
assert "class" in classification assert "class" in classification
assert "score" in classification assert "score" in classification
def test_cohere_representation_model_english_light_language_detection( def test_cohere_representation_model_english_light_language_detection(
cohere_instance,): cohere_instance,
):
# Test using the Representation model for English light language detection # Test using the Representation model for English light language detection
cohere_instance.model = "embed-english-light-v3.0" cohere_instance.model = "embed-english-light-v3.0"
language = cohere_instance.detect_language( language = cohere_instance.detect_language(
"Detect the language of English light text.") "Detect the language of English light text."
)
assert isinstance(language, str) assert isinstance(language, str)
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
cohere_instance,): cohere_instance,
):
# Test handling max tokens limit exceeded error for English light model # Test handling max tokens limit exceeded error for English light model
cohere_instance.model = "embed-english-light-v3.0" cohere_instance.model = "embed-english-light-v3.0"
cohere_instance.max_tokens = 10 cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit" prompt = (
" for English light model.") "This is a test prompt that will exceed the max tokens limit"
" for English light model."
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cohere_instance.embed(prompt) cohere_instance.embed(prompt)
@ -633,7 +748,9 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
def test_cohere_command_model(cohere_instance): def test_cohere_command_model(cohere_instance):
# Test using the Command model for text generation # Test using the Command model for text generation
cohere_instance.model = "command" cohere_instance.model = "command"
response = cohere_instance("Generate text using the Command model.") response = cohere_instance(
"Generate text using the Command model."
)
assert isinstance(response, str) assert isinstance(response, str)
@ -647,7 +764,9 @@ def test_cohere_invalid_model(cohere_instance):
cohere_instance("Generate text using an invalid model.") cohere_instance("Generate text using an invalid model.")
def test_cohere_base_model_generation_with_max_tokens(cohere_instance,): def test_cohere_base_model_generation_with_max_tokens(
cohere_instance,
):
# Test generating text using the base model with a specified max_tokens limit # Test generating text using the base model with a specified max_tokens limit
cohere_instance.model = "base" cohere_instance.model = "base"
cohere_instance.max_tokens = 20 cohere_instance.max_tokens = 20
@ -30,30 +30,45 @@ def test_run_text_to_speech(eleven_labs_tool):
def test_play_speech(eleven_labs_tool): def test_play_speech(eleven_labs_tool):
with patch("builtins.open", mock_open(read_data="fake_audio_data")): with patch(
"builtins.open", mock_open(read_data="fake_audio_data")
):
eleven_labs_tool.play(EXPECTED_SPEECH_FILE) eleven_labs_tool.play(EXPECTED_SPEECH_FILE)
def test_stream_speech(eleven_labs_tool): def test_stream_speech(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file: with patch(
"tempfile.NamedTemporaryFile", mock_open()
) as mock_file:
eleven_labs_tool.stream_speech(SAMPLE_TEXT) eleven_labs_tool.stream_speech(SAMPLE_TEXT)
mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False) mock_file.assert_called_with(
mode="bx", suffix=".wav", delete=False
)
# Testing fixture and environment variables # Testing fixture and environment variables
def test_api_key_validation(eleven_labs_tool): def test_api_key_validation(eleven_labs_tool):
with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY): with patch(
"langchain.utils.get_from_dict_or_env", return_value=API_KEY
):
values = {"eleven_api_key": None} values = {"eleven_api_key": None}
validated_values = eleven_labs_tool.validate_environment(values) validated_values = eleven_labs_tool.validate_environment(
values
)
assert "eleven_api_key" in validated_values assert "eleven_api_key" in validated_values
# Mocking the external library # Mocking the external library
def test_run_text_to_speech_with_mock(eleven_labs_tool): def test_run_text_to_speech_with_mock(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch( with patch(
"your_module._import_elevenlabs") as mock_elevenlabs: "tempfile.NamedTemporaryFile", mock_open()
) as mock_file, patch(
"your_module._import_elevenlabs"
) as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.return_value = (b"fake_audio_data") mock_elevenlabs_instance.generate.return_value = (
b"fake_audio_data"
)
eleven_labs_tool.run(SAMPLE_TEXT) eleven_labs_tool.run(SAMPLE_TEXT)
assert mock_file.call_args[1]["suffix"] == ".wav" assert mock_file.call_args[1]["suffix"] == ".wav"
assert mock_file.call_args[1]["delete"] is False assert mock_file.call_args[1]["delete"] is False
@ -65,11 +80,14 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool):
with patch("your_module._import_elevenlabs") as mock_elevenlabs: with patch("your_module._import_elevenlabs") as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.side_effect = Exception( mock_elevenlabs_instance.generate.side_effect = Exception(
"Test Exception") "Test Exception"
)
with pytest.raises( with pytest.raises(
RuntimeError, RuntimeError,
match=("Error while running ElevenLabsText2SpeechTool: Test" match=(
" Exception"), "Error while running ElevenLabsText2SpeechTool: Test"
" Exception"
),
): ):
eleven_labs_tool.run(SAMPLE_TEXT) eleven_labs_tool.run(SAMPLE_TEXT)
@ -79,7 +97,9 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool):
"model", "model",
[ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL], [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL],
) )
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model): def test_run_text_to_speech_with_different_models(
eleven_labs_tool, model
):
eleven_labs_tool.model = model eleven_labs_tool.model = model
speech_file = eleven_labs_tool.run(SAMPLE_TEXT) speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
assert isinstance(speech_file, str) assert isinstance(speech_file, str)
@ -39,4 +39,6 @@ def test_fire_function_caller_run(mocker):
tokenizer.batch_decode.assert_called_once_with(generated_ids) tokenizer.batch_decode.assert_called_once_with(generated_ids)
# Assert the decoded output is printed # Assert the decoded output is printed
assert decoded_output in mocker.patch.object(print, "call_args_list") assert decoded_output in mocker.patch.object(
print, "call_args_list"
)
@ -38,7 +38,9 @@ def fuyu_instance():
# Test using the fixture. # Test using the fixture.
def test_fuyu_processor_initialization(fuyu_instance): def test_fuyu_processor_initialization(fuyu_instance):
assert isinstance(fuyu_instance.processor, FuyuProcessor) assert isinstance(fuyu_instance.processor, FuyuProcessor)
assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) assert isinstance(
fuyu_instance.image_processor, FuyuImageProcessor
)
# Test exception when providing an invalid image path. # Test exception when providing an invalid image path.
@ -49,7 +51,6 @@ def test_invalid_image_path(fuyu_instance):
# Using monkeypatch to replace the Image.open method to simulate a failure. # Using monkeypatch to replace the Image.open method to simulate a failure.
def test_image_open_failure(fuyu_instance, monkeypatch): def test_image_open_failure(fuyu_instance, monkeypatch):
def mock_open(*args, **kwargs): def mock_open(*args, **kwargs):
raise Exception("Mocked failure") raise Exception("Mocked failure")
@ -78,9 +79,13 @@ def test_tokenizer_type(fuyu_instance):
def test_processor_has_image_processor_and_tokenizer(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
assert (fuyu_instance.processor.image_processor == assert (
fuyu_instance.image_processor) fuyu_instance.processor.image_processor
assert (fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer) == fuyu_instance.image_processor
)
assert (
fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
)
def test_model_device_map(fuyu_instance): def test_model_device_map(fuyu_instance):
@ -139,14 +144,22 @@ def test_get_img_invalid_path(fuyu_instance):
# Test `run` method with valid inputs # Test `run` method with valid inputs
def test_run_valid_inputs(fuyu_instance): def test_run_valid_inputs(fuyu_instance):
with patch.object(fuyu_instance, "get_img") as mock_get_img, patch.object( with patch.object(
fuyu_instance, "processor") as mock_processor, patch.object( fuyu_instance, "get_img"
fuyu_instance, "model") as mock_model: ) as mock_get_img, patch.object(
fuyu_instance, "processor"
) as mock_processor, patch.object(
fuyu_instance, "model"
) as mock_model:
mock_get_img.return_value = "Test image" mock_get_img.return_value = "Test image"
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test text"] mock_processor.batch_decode.return_value = ["Test text"]
result = fuyu_instance.run("Hello, world!", "valid/path/to/image.png") result = fuyu_instance.run(
"Hello, world!", "valid/path/to/image.png"
)
assert result == ["Test text"] assert result == ["Test text"]
@ -173,7 +186,9 @@ def test_run_invalid_image_path(fuyu_instance):
with patch.object(fuyu_instance, "get_img") as mock_get_img: with patch.object(fuyu_instance, "get_img") as mock_get_img:
mock_get_img.side_effect = FileNotFoundError mock_get_img.side_effect = FileNotFoundError
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
fuyu_instance.run("Hello, world!", "invalid/path/to/image.png") fuyu_instance.run(
"Hello, world!", "invalid/path/to/image.png"
)
# Test `__init__` method with default parameters # Test `__init__` method with default parameters
@ -24,8 +24,12 @@ def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
assert model.model is mock_genai_model assert model.model is mock_genai_model
def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model): def test_gemini_init_custom_params(
model = Gemini(model_name="custom-model", gemini_api_key="custom-api-key") mock_gemini_api_key, mock_genai_model
):
model = Gemini(
model_name="custom-model", gemini_api_key="custom-api-key"
)
assert model.model_name == "custom-model" assert model.model_name == "custom-model"
assert model.gemini_api_key == "custom-api-key" assert model.gemini_api_key == "custom-api-key"
assert model.model is mock_genai_model assert model.model is mock_genai_model
@ -50,13 +54,16 @@ def test_gemini_run_with_img(
response = model.run(task=task, img=img) response = model.run(task=task, img=img)
assert response == "Generated response" assert response == "Generated response"
mock_generate_content.assert_called_with(content=[task, "Processed image"]) mock_generate_content.assert_called_with(
content=[task, "Processed image"]
)
mock_process_img.assert_called_with(img=img) mock_process_img.assert_called_with(img=img)
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") @patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key, def test_gemini_run_without_img(
mock_genai_model): mock_generate_content, mock_gemini_api_key, mock_genai_model
):
model = Gemini() model = Gemini()
task = "A cat" task = "A cat"
response_mock = Mock(text="Generated response") response_mock = Mock(text="Generated response")
@ -69,8 +76,9 @@ def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key,
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") @patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_exception(mock_generate_content, mock_gemini_api_key, def test_gemini_run_exception(
mock_genai_model): mock_generate_content, mock_gemini_api_key, mock_genai_model
):
model = Gemini() model = Gemini()
task = "A cat" task = "A cat"
mock_generate_content.side_effect = Exception("Test exception") mock_generate_content.side_effect = Exception("Test exception")
@ -88,23 +96,30 @@ def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
with patch("builtins.open", create=True) as open_mock: with patch("builtins.open", create=True) as open_mock:
open_mock.return_value.__enter__.return_value.read.return_value = ( open_mock.return_value.__enter__.return_value.read.return_value = (
img_data) img_data
)
processed_img = model.process_img(img) processed_img = model.process_img(img)
assert processed_img == [{"mime_type": "image/png", "data": img_data}] assert processed_img == [
{"mime_type": "image/png", "data": img_data}
]
open_mock.assert_called_with(img, "rb") open_mock.assert_called_with(img, "rb")
# Test Gemini initialization with missing API key # Test Gemini initialization with missing API key
def test_gemini_init_missing_api_key(): def test_gemini_init_missing_api_key():
with pytest.raises(ValueError, match="Please provide a Gemini API key"): with pytest.raises(
ValueError, match="Please provide a Gemini API key"
):
Gemini(gemini_api_key=None) Gemini(gemini_api_key=None)
# Test Gemini initialization with missing model name # Test Gemini initialization with missing model name
def test_gemini_init_missing_model_name(): def test_gemini_init_missing_model_name():
with pytest.raises(ValueError, match="Please provide a model name"): with pytest.raises(
ValueError, match="Please provide a model name"
):
Gemini(model_name=None) Gemini(model_name=None)
@ -126,20 +141,26 @@ def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
# Test Gemini process_img method with missing image # Test Gemini process_img method with missing image
def test_gemini_process_img_missing_image(mock_gemini_api_key, def test_gemini_process_img_missing_image(
mock_genai_model): mock_gemini_api_key, mock_genai_model
):
model = Gemini() model = Gemini()
img = None img = None
with pytest.raises(ValueError, match="Please provide an image to process"): with pytest.raises(
ValueError, match="Please provide an image to process"
):
model.process_img(img=img) model.process_img(img=img)
# Test Gemini process_img method with missing image type # Test Gemini process_img method with missing image type
def test_gemini_process_img_missing_image_type(mock_gemini_api_key, def test_gemini_process_img_missing_image_type(
mock_genai_model): mock_gemini_api_key, mock_genai_model
):
model = Gemini() model = Gemini()
img = "cat.png" img = "cat.png"
with pytest.raises(ValueError, match="Please provide the image type"): with pytest.raises(
ValueError, match="Please provide the image type"
):
model.process_img(img=img, type=None) model.process_img(img=img, type=None)
@ -147,7 +168,9 @@ def test_gemini_process_img_missing_image_type(mock_gemini_api_key,
def test_gemini_process_img_missing_api_key(mock_genai_model): def test_gemini_process_img_missing_api_key(mock_genai_model):
model = Gemini(gemini_api_key=None) model = Gemini(gemini_api_key=None)
img = "cat.png" img = "cat.png"
with pytest.raises(ValueError, match="Please provide a Gemini API key"): with pytest.raises(
ValueError, match="Please provide a Gemini API key"
):
model.process_img(img=img, type="image/png") model.process_img(img=img, type="image/png")
@ -170,7 +193,9 @@ def test_gemini_run_mock_img_processing(
response = model.run(task=task, img=img) response = model.run(task=task, img=img)
assert response == "Generated response" assert response == "Generated response"
mock_generate_content.assert_called_with(content=[task, "Processed image"]) mock_generate_content.assert_called_with(
content=[task, "Processed image"]
)
mock_process_img.assert_called_with(img=img) mock_process_img.assert_called_with(img=img)
@ -11,13 +11,16 @@ except ImportError:
@pytest.fixture @pytest.fixture
def api(): def api():
return Gigabind(host="localhost", port=8000, endpoint="embeddings") return Gigabind(
host="localhost", port=8000, endpoint="embeddings"
)
@pytest.fixture @pytest.fixture
def mock(requests_mock): def mock(requests_mock):
requests_mock.post("http://localhost:8000/embeddings", requests_mock.post(
json={"result": "success"}) "http://localhost:8000/embeddings", json={"result": "success"}
)
return requests_mock return requests_mock
@ -37,9 +40,9 @@ def test_run_with_audio(api, mock):
def test_run_with_all(api, mock): def test_run_with_all(api, mock):
response = api.run(text="Hello, world!", response = api.run(
vision="image.jpg", text="Hello, world!", vision="image.jpg", audio="audio.mp3"
audio="audio.mp3") )
assert response == {"result": "success"} assert response == {"result": "success"}
@ -62,20 +65,9 @@ def test_retry_on_failure(api, requests_mock):
requests_mock.post( requests_mock.post(
"http://localhost:8000/embeddings", "http://localhost:8000/embeddings",
[ [
{ {"status_code": 500, "json": {}},
"status_code": 500, {"status_code": 500, "json": {}},
"json": {} {"status_code": 200, "json": {"result": "success"}},
},
{
"status_code": 500,
"json": {}
},
{
"status_code": 200,
"json": {
"result": "success"
}
},
], ],
) )
response = api.run(text="Hello, world!") response = api.run(text="Hello, world!")
@ -86,18 +78,9 @@ def test_retry_exhausted(api, requests_mock):
requests_mock.post( requests_mock.post(
"http://localhost:8000/embeddings", "http://localhost:8000/embeddings",
[ [
{ {"status_code": 500, "json": {}},
"status_code": 500, {"status_code": 500, "json": {}},
"json": {} {"status_code": 500, "json": {}},
},
{
"status_code": 500,
"json": {}
},
{
"status_code": 500,
"json": {}
},
], ],
) )
response = api.run(text="Hello, world!") response = api.run(text="Hello, world!")
@ -110,7 +93,9 @@ def test_proxy_url(api):
def test_invalid_response(api, requests_mock): def test_invalid_response(api, requests_mock):
requests_mock.post("http://localhost:8000/embeddings", text="not json") requests_mock.post(
"http://localhost:8000/embeddings", text="not json"
)
response = api.run(text="Hello, world!") response = api.run(text="Hello, world!")
assert response is None assert response is None
@ -125,7 +110,9 @@ def test_connection_error(api, requests_mock):
def test_http_error(api, requests_mock): def test_http_error(api, requests_mock):
requests_mock.post("http://localhost:8000/embeddings", status_code=500) requests_mock.post(
"http://localhost:8000/embeddings", status_code=500
)
response = api.run(text="Hello, world!") response = api.run(text="Hello, world!")
assert response is None assert response is None
@ -161,7 +148,9 @@ def test_run_with_large_all(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions large_text = "Hello, world! " * 10000 # 10,000 repetitions
large_vision = "image.jpg" * 10000 # 10,000 repetitions large_vision = "image.jpg" * 10000 # 10,000 repetitions
large_audio = "audio.mp3" * 10000 # 10,000 repetitions large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(text=large_text, vision=large_vision, audio=large_audio) response = api.run(
text=large_text, vision=large_vision, audio=large_audio
)
assert response == {"result": "success"} assert response == {"result": "success"}
@ -46,7 +46,9 @@ def test_run_success(vision_api):
def test_run_request_error(vision_api): def test_run_request_error(vision_api):
with patch("requests.post", side_effect=RequestException("Request Error")): with patch(
"requests.post", side_effect=RequestException("Request Error")
):
with pytest.raises(RequestException): with pytest.raises(RequestException):
vision_api.run("What is this?", img) vision_api.run("What is this?", img)
@ -62,7 +64,9 @@ def test_run_response_error(vision_api):
def test_call(vision_api): def test_call(vision_api):
expected_response = {"choices": [{"text": "This is the model's response."}]} expected_response = {
"choices": [{"text": "This is the model's response."}]
}
with patch( with patch(
"requests.post", "requests.post",
return_value=Mock(json=lambda: expected_response), return_value=Mock(json=lambda: expected_response),
@ -91,7 +95,9 @@ def test_initialization_with_custom_key():
def test_run_with_exception(gpt_api): def test_run_with_exception(gpt_api):
task = "What is in the image?" task = "What is in the image?"
img_url = img img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")): with patch(
"requests.post", side_effect=Exception("Test Exception")
):
with pytest.raises(Exception): with pytest.raises(Exception):
gpt_api.run(task, img_url) gpt_api.run(task, img_url)
@ -99,10 +105,14 @@ def test_run_with_exception(gpt_api):
def test_call_method_successful_response(gpt_api): def test_call_method_successful_response(gpt_api):
task = "What is in the image?" task = "What is in the image?"
img_url = img img_url = img
response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]} response_json = {
"choices": [{"text": "Answer from GPT-4 Vision"}]
}
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = response_json mock_response.json.return_value = response_json
with patch("requests.post", return_value=mock_response) as mock_post: with patch(
"requests.post", return_value=mock_response
) as mock_post:
result = gpt_api(task, img_url) result = gpt_api(task, img_url)
mock_post.assert_called_once() mock_post.assert_called_once()
assert result == response_json assert result == response_json
@ -111,7 +121,9 @@ def test_call_method_successful_response(gpt_api):
def test_call_method_with_exception(gpt_api): def test_call_method_with_exception(gpt_api):
task = "What is in the image?" task = "What is in the image?"
img_url = img img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")): with patch(
"requests.post", side_effect=Exception("Test Exception")
):
with pytest.raises(Exception): with pytest.raises(Exception):
gpt_api(task, img_url) gpt_api(task, img_url)
@ -119,17 +131,16 @@ def test_call_method_with_exception(gpt_api):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_arun_success(vision_api): async def test_arun_success(vision_api):
expected_response = { expected_response = {
"choices": [{ "choices": [
"message": { {"message": {"content": "This is the model's response."}}
"content": "This is the model's response." ]
}
}]
} }
with patch( with patch(
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock( return_value=AsyncMock(
return_value=expected_response)), json=AsyncMock(return_value=expected_response)
),
) as mock_post: ) as mock_post:
result = await vision_api.arun("What is this?", img) result = await vision_api.arun("What is this?", img)
mock_post.assert_called_once() mock_post.assert_called_once()
@ -149,11 +160,9 @@ async def test_arun_request_error(vision_api):
def test_run_many_success(vision_api): def test_run_many_success(vision_api):
expected_response = { expected_response = {
"choices": [{ "choices": [
"message": { {"message": {"content": "This is the model's response."}}
"content": "This is the model's response." ]
}
}]
} }
with patch( with patch(
"requests.post", "requests.post",
@ -170,7 +179,9 @@ def test_run_many_success(vision_api):
def test_run_many_request_error(vision_api): def test_run_many_request_error(vision_api):
with patch("requests.post", side_effect=RequestException("Request Error")): with patch(
"requests.post", side_effect=RequestException("Request Error")
):
tasks = ["What is this?", "What is that?"] tasks = ["What is this?", "What is that?"]
imgs = [img, img] imgs = [img, img]
with pytest.raises(RequestException): with pytest.raises(RequestException):
@ -182,7 +193,9 @@ async def test_arun_json_decode_error(vision_api):
with patch( with patch(
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)), return_value=AsyncMock(
json=AsyncMock(side_effect=ValueError)
),
): ):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -194,7 +207,9 @@ async def test_arun_api_error(vision_api):
with patch( with patch(
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(return_value=error_response)), return_value=AsyncMock(
json=AsyncMock(return_value=error_response)
),
): ):
with pytest.raises(Exception, match="API Error"): with pytest.raises(Exception, match="API Error"):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -206,8 +221,9 @@ async def test_arun_unexpected_response(vision_api):
with patch( with patch(
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock( return_value=AsyncMock(
return_value=unexpected_response)), json=AsyncMock(return_value=unexpected_response)
),
): ):
with pytest.raises(Exception, match="Unexpected response"): with pytest.raises(Exception, match="Unexpected response"):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -133,7 +133,9 @@ def test_llm_set_repitition_penalty(llm_instance):
def test_llm_set_no_repeat_ngram_size(llm_instance): def test_llm_set_no_repeat_ngram_size(llm_instance):
new_no_repeat_ngram_size = 6 new_no_repeat_ngram_size = 6
llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size) llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size)
assert (llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size) assert (
llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size
)
# Test for setting temperature # Test for setting temperature
@ -183,7 +185,9 @@ def test_llm_set_model_id(llm_instance):
# Test for setting model # Test for setting model
@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained") @patch(
"swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
)
def test_llm_set_model(mock_model, llm_instance): def test_llm_set_model(mock_model, llm_instance):
mock_model.return_value = "mocked model" mock_model.return_value = "mocked model"
llm_instance.set_model(mock_model) llm_instance.set_model(mock_model)
@ -14,14 +14,19 @@ def mock_pipeline():
@pytest.fixture @pytest.fixture
def pipeline(mock_pipeline): def pipeline(mock_pipeline):
return HuggingfacePipeline("text-generation", return HuggingfacePipeline(
"meta-llama/Llama-2-13b-chat-hf") "text-generation", "meta-llama/Llama-2-13b-chat-hf"
)
def test_init(pipeline, mock_pipeline): def test_init(pipeline, mock_pipeline):
assert pipeline.task_type == "text-generation" assert pipeline.task_type == "text-generation"
assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf" assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
assert (pipeline.use_fp8 is True if torch.cuda.is_available() else False) assert (
pipeline.use_fp8 is True
if torch.cuda.is_available()
else False
)
mock_pipeline.assert_called_once_with( mock_pipeline.assert_called_once_with(
"text-generation", "text-generation",
"meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-13b-chat-hf",
@ -46,5 +51,6 @@ def test_run_with_different_task(pipeline, mock_pipeline):
mock_pipeline.return_value = "Generated text" mock_pipeline.return_value = "Generated text"
result = pipeline.run("text-classification", "Hello, world!") result = pipeline.run("text-classification", "Hello, world!")
assert result == "Generated text" assert result == "Generated text"
mock_pipeline.assert_called_once_with("text-classification", mock_pipeline.assert_called_once_with(
"Hello, world!") "text-classification", "Hello, world!"
)
@ -18,7 +18,10 @@ def llm_instance():
# Test for instantiation and attributes # Test for instantiation and attributes
def test_llm_initialization(llm_instance): def test_llm_initialization(llm_instance):
assert (llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha") assert (
llm_instance.model_id
== "NousResearch/Nous-Hermes-2-Vision-Alpha"
)
assert llm_instance.max_length == 500 assert llm_instance.max_length == 500
# ... add more assertions for all default attributes # ... add more assertions for all default attributes
@ -85,11 +88,14 @@ def test_llm_memory_consumption(llm_instance):
) )
def test_llm_initialization_params(model_id, max_length): def test_llm_initialization_params(model_id, max_length):
if max_length: if max_length:
instance = HuggingfaceLLM(model_id=model_id, max_length=max_length) instance = HuggingfaceLLM(
model_id=model_id, max_length=max_length
)
assert instance.max_length == max_length assert instance.max_length == max_length
else: else:
instance = HuggingfaceLLM(model_id=model_id) instance = HuggingfaceLLM(model_id=model_id)
assert (instance.max_length == 500 assert (
instance.max_length == 500
) # Assuming 500 is the default max_length ) # Assuming 500 is the default max_length
@ -138,7 +144,9 @@ def test_llm_run_output_length(mock_run, llm_instance):
# Test the tokenizer handling special tokens correctly # Test the tokenizer handling special tokens correctly
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance): def test_llm_tokenizer_special_tokens(
mock_decode, mock_encode, llm_instance
):
mock_encode.return_value = "encoded input with special tokens" mock_encode.return_value = "encoded input with special tokens"
mock_decode.return_value = "decoded output with special tokens" mock_decode.return_value = "decoded output with special tokens"
result = llm_instance.run("test task with special tokens") result = llm_instance.run("test task with special tokens")
@ -164,8 +172,9 @@ def test_llm_response_time(mock_run, llm_instance):
start_time = time.time() start_time = time.time()
llm_instance.run("test task for response time") llm_instance.run("test task for response time")
end_time = time.time() end_time = time.time()
assert (end_time - start_time assert (
< 1) # Assuming the response should be faster than 1 second end_time - start_time < 1
) # Assuming the response should be faster than 1 second
# Test the logging of a warning for long inputs # Test the logging of a warning for long inputs
@ -188,10 +197,13 @@ def test_llm_run_model_exception(mock_generate, llm_instance):
# Test the behavior when GPU is forced but not available # Test the behavior when GPU is forced but not available
@patch("torch.cuda.is_available", return_value=False) @patch("torch.cuda.is_available", return_value=False)
def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): def test_llm_force_gpu_when_unavailable(
mock_is_available, llm_instance
):
with pytest.raises(EnvironmentError): with pytest.raises(EnvironmentError):
llm_instance.set_device( llm_instance.set_device(
"cuda") # Attempt to set CUDA when it's not available "cuda"
) # Attempt to set CUDA when it's not available
# Test for proper cleanup after model use (releasing resources) # Test for proper cleanup after model use (releasing resources)
@ -209,8 +221,9 @@ def test_llm_multilingual_input(mock_run, llm_instance):
mock_run.return_value = "mocked multilingual output" mock_run.return_value = "mocked multilingual output"
multilingual_input = "Bonjour, ceci est un test multilingue." multilingual_input = "Bonjour, ceci est un test multilingue."
result = llm_instance.run(multilingual_input) result = llm_instance.run(multilingual_input)
assert isinstance(result, assert isinstance(
str) # Simple check to ensure output is string type result, str
) # Simple check to ensure output is string type
# Test caching mechanism to prevent re-running the same inputs # Test caching mechanism to prevent re-running the same inputs
@ -13,8 +13,8 @@ from swarms.models.idefics import (
@pytest.fixture @pytest.fixture
def idefics_instance(): def idefics_instance():
with patch( with patch(
"torch.cuda.is_available", "torch.cuda.is_available", return_value=False
return_value=False): # Assuming tests are run on CPU for simplicity ): # Assuming tests are run on CPU for simplicity
instance = Idefics() instance = Idefics()
return instance return instance
@ -46,10 +46,14 @@ def test_init_device(device, expected):
# Test `run` method # Test `run` method
def test_run(idefics_instance): def test_run(idefics_instance):
prompts = [["User: Test"]] prompts = [["User: Test"]]
with patch.object(idefics_instance, with patch.object(
"processor") as mock_processor, patch.object( idefics_instance, "processor"
idefics_instance, "model") as mock_model: ) as mock_processor, patch.object(
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"] mock_processor.batch_decode.return_value = ["Test"]
@ -61,10 +65,14 @@ def test_run(idefics_instance):
# Test `__call__` method (using the same logic as run for simplicity) # Test `__call__` method (using the same logic as run for simplicity)
def test_call(idefics_instance): def test_call(idefics_instance):
prompts = [["User: Test"]] prompts = [["User: Test"]]
with patch.object(idefics_instance, with patch.object(
"processor") as mock_processor, patch.object( idefics_instance, "processor"
idefics_instance, "model") as mock_model: ) as mock_processor, patch.object(
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"] mock_processor.batch_decode.return_value = ["Test"]
@ -77,7 +85,9 @@ def test_call(idefics_instance):
def test_chat(idefics_instance): def test_chat(idefics_instance):
user_input = "User: Hello" user_input = "User: Hello"
response = "Model: Hi there!" response = "Model: Hi there!"
with patch.object(idefics_instance, "run", return_value=[response]): with patch.object(
idefics_instance, "run", return_value=[response]
):
result = idefics_instance.chat(user_input) result = idefics_instance.chat(user_input)
assert result == response assert result == response
@ -87,13 +97,16 @@ def test_chat(idefics_instance):
# Test `set_checkpoint` method # Test `set_checkpoint` method
def test_set_checkpoint(idefics_instance): def test_set_checkpoint(idefics_instance):
new_checkpoint = "new_checkpoint" new_checkpoint = "new_checkpoint"
with patch.object(IdeficsForVisionText2Text, with patch.object(
"from_pretrained") as mock_from_pretrained, patch.object( IdeficsForVisionText2Text, "from_pretrained"
AutoProcessor, "from_pretrained"): ) as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"
):
idefics_instance.set_checkpoint(new_checkpoint) idefics_instance.set_checkpoint(new_checkpoint)
mock_from_pretrained.assert_called_with(new_checkpoint, mock_from_pretrained.assert_called_with(
torch_dtype=torch.bfloat16) new_checkpoint, torch_dtype=torch.bfloat16
)
# Test `set_device` method # Test `set_device` method
@ -130,10 +143,14 @@ def test_run_with_empty_prompts(idefics_instance):
# Test `run` method with batched_mode set to False # Test `run` method with batched_mode set to False
def test_run_batched_mode_false(idefics_instance): def test_run_batched_mode_false(idefics_instance):
task = "User: Test" task = "User: Test"
with patch.object(idefics_instance, with patch.object(
"processor") as mock_processor, patch.object( idefics_instance, "processor"
idefics_instance, "model") as mock_model: ) as mock_processor, patch.object(
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"] mock_processor.batch_decode.return_value = ["Test"]
@ -146,7 +163,9 @@ def test_run_batched_mode_false(idefics_instance):
# Test `run` method with an exception # Test `run` method with an exception
def test_run_with_exception(idefics_instance): def test_run_with_exception(idefics_instance):
task = "User: Test" task = "User: Test"
with patch.object(idefics_instance, "processor") as mock_processor: with patch.object(
idefics_instance, "processor"
) as mock_processor:
mock_processor.side_effect = Exception("Test exception") mock_processor.side_effect = Exception("Test exception")
with pytest.raises(Exception): with pytest.raises(Exception):
idefics_instance.run(task) idefics_instance.run(task)
@ -155,14 +174,17 @@ def test_run_with_exception(idefics_instance):
# Test `set_model_name` method # Test `set_model_name` method
def test_set_model_name(idefics_instance): def test_set_model_name(idefics_instance):
new_model_name = "new_model_name" new_model_name = "new_model_name"
with patch.object(IdeficsForVisionText2Text, with patch.object(
"from_pretrained") as mock_from_pretrained, patch.object( IdeficsForVisionText2Text, "from_pretrained"
AutoProcessor, "from_pretrained"): ) as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"
):
idefics_instance.set_model_name(new_model_name) idefics_instance.set_model_name(new_model_name)
assert idefics_instance.model_name == new_model_name assert idefics_instance.model_name == new_model_name
mock_from_pretrained.assert_called_with(new_model_name, mock_from_pretrained.assert_called_with(
torch_dtype=torch.bfloat16) new_model_name, torch_dtype=torch.bfloat16
)
# Test `__init__` method with device set to None # Test `__init__` method with device set to None
@ -16,7 +16,9 @@ def mock_image_request():
img_data = open(TEST_IMAGE_URL, "rb").read() img_data = open(TEST_IMAGE_URL, "rb").read()
mock_resp = Mock() mock_resp = Mock()
mock_resp.raw = img_data mock_resp.raw = img_data
with patch.object(requests, "get", return_value=mock_resp) as _fixture: with patch.object(
requests, "get", return_value=mock_resp
) as _fixture:
yield _fixture yield _fixture
@ -45,16 +47,18 @@ def test_get_image(mock_image_request):
# Test multimodal grounding # Test multimodal grounding
def test_multimodal_grounding(mock_image_request): def test_multimodal_grounding(mock_image_request):
kosmos = Kosmos() kosmos = Kosmos()
kosmos.multimodal_grounding("Find the red apple in the image.", kosmos.multimodal_grounding(
TEST_IMAGE_URL) "Find the red apple in the image.", TEST_IMAGE_URL
)
# TODO: Validate the result if possible # TODO: Validate the result if possible
# Test referring expression comprehension # Test referring expression comprehension
def test_referring_expression_comprehension(mock_image_request): def test_referring_expression_comprehension(mock_image_request):
kosmos = Kosmos() kosmos = Kosmos()
kosmos.referring_expression_comprehension("Show me the green bottle.", kosmos.referring_expression_comprehension(
TEST_IMAGE_URL) "Show me the green bottle.", TEST_IMAGE_URL
)
# TODO: Validate the result if possible # TODO: Validate the result if possible
@ -89,7 +93,6 @@ IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=fo
# Mock response for requests.get() # Mock response for requests.get()
class MockResponse: class MockResponse:
@staticmethod @staticmethod
def json(): def json():
return {} return {}
@ -108,23 +111,30 @@ def kosmos():
# Mocking the requests.get() method # Mocking the requests.get() method
@pytest.fixture @pytest.fixture
def mock_request_get(monkeypatch): def mock_request_get(monkeypatch):
monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse()) monkeypatch.setattr(
requests, "get", lambda url, **kwargs: MockResponse()
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding(kosmos): def test_multimodal_grounding(kosmos):
kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1) kosmos.multimodal_grounding(
"Find the red apple in the image.", IMG_URL1
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension(kosmos): def test_referring_expression_comprehension(kosmos):
kosmos.referring_expression_comprehension("Show me the green bottle.", kosmos.referring_expression_comprehension(
IMG_URL2) "Show me the green bottle.", IMG_URL2
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_generation(kosmos): def test_referring_expression_generation(kosmos):
kosmos.referring_expression_generation("It is on the table.", IMG_URL3) kosmos.referring_expression_generation(
"It is on the table.", IMG_URL3
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
@ -144,13 +154,16 @@ def test_grounded_image_captioning_detailed(kosmos):
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding_2(kosmos): def test_multimodal_grounding_2(kosmos):
kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2) kosmos.multimodal_grounding(
"Find the yellow fruit in the image.", IMG_URL2
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension_2(kosmos): def test_referring_expression_comprehension_2(kosmos):
kosmos.referring_expression_comprehension("Where is the water bottle?", kosmos.referring_expression_comprehension(
IMG_URL3) "Where is the water bottle?", IMG_URL3
)
@pytest.mark.usefixtures("mock_request_get") @pytest.mark.usefixtures("mock_request_get")
@ -18,7 +18,6 @@ def test_llama_model_loading(llama_caller):
# Test adding and calling custom functions # Test adding and calling custom functions
def test_llama_custom_function(llama_caller): def test_llama_custom_function(llama_caller):
def sample_function(arg1, arg2): def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}" return f"Sample function called with args: {arg1}, {arg2}"
@ -40,11 +39,13 @@ def test_llama_custom_function(llama_caller):
], ],
) )
result = llama_caller.call_function("sample_function", result = llama_caller.call_function(
arg1="arg1_value", "sample_function", arg1="arg1_value", arg2="arg2_value"
arg2="arg2_value") )
assert ( assert (
result == "Sample function called with args: arg1_value, arg2_value") result
== "Sample function called with args: arg1_value, arg2_value"
)
# Test streaming user prompts # Test streaming user prompts
@ -63,7 +64,6 @@ def test_llama_custom_function_not_found(llama_caller):
# Test invalid arguments for custom function # Test invalid arguments for custom function
def test_llama_custom_function_invalid_arguments(llama_caller): def test_llama_custom_function_invalid_arguments(llama_caller):
def sample_function(arg1, arg2): def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}" return f"Sample function called with args: {arg1}, {arg2}"
@ -86,7 +86,9 @@ def test_llama_custom_function_invalid_arguments(llama_caller):
) )
with pytest.raises(TypeError): with pytest.raises(TypeError):
llama_caller.call_function("sample_function", arg1="arg1_value") llama_caller.call_function(
"sample_function", arg1="arg1_value"
)
# Test streaming with custom runtime # Test streaming with custom runtime
@ -21,18 +21,22 @@ def test_mixtral_run(mock_model, mock_tokenizer):
mixtral = Mixtral() mixtral = Mixtral()
mock_tokenizer_instance = MagicMock() mock_tokenizer_instance = MagicMock()
mock_model_instance = MagicMock() mock_model_instance = MagicMock()
mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) mock_tokenizer.from_pretrained.return_value = (
mock_tokenizer_instance
)
mock_model.from_pretrained.return_value = mock_model_instance mock_model.from_pretrained.return_value = mock_model_instance
mock_tokenizer_instance.return_tensors = "pt" mock_tokenizer_instance.return_tensors = "pt"
mock_model_instance.generate.return_value = [101, 102, 103] mock_model_instance.generate.return_value = [101, 102, 103]
mock_tokenizer_instance.decode.return_value = "Generated text" mock_tokenizer_instance.decode.return_value = "Generated text"
result = mixtral.run("Test task") result = mixtral.run("Test task")
assert result == "Generated text" assert result == "Generated text"
mock_tokenizer_instance.assert_called_once_with("Test task", mock_tokenizer_instance.assert_called_once_with(
return_tensors="pt") "Test task", return_tensors="pt"
)
mock_model_instance.generate.assert_called_once() mock_model_instance.generate.assert_called_once()
mock_tokenizer_instance.decode.assert_called_once_with( mock_tokenizer_instance.decode.assert_called_once_with(
[101, 102, 103], skip_special_tokens=True) [101, 102, 103], skip_special_tokens=True
)
@patch("swarms.models.mixtral.AutoTokenizer") @patch("swarms.models.mixtral.AutoTokenizer")
@ -41,7 +45,9 @@ def test_mixtral_run_error(mock_model, mock_tokenizer):
mixtral = Mixtral() mixtral = Mixtral()
mock_tokenizer_instance = MagicMock() mock_tokenizer_instance = MagicMock()
mock_model_instance = MagicMock() mock_model_instance = MagicMock()
mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance) mock_tokenizer.from_pretrained.return_value = (
mock_tokenizer_instance
)
mock_model.from_pretrained.return_value = mock_model_instance mock_model.from_pretrained.return_value = mock_model_instance
mock_tokenizer_instance.return_tensors = "pt" mock_tokenizer_instance.return_tensors = "pt"
mock_model_instance.generate.side_effect = Exception("Test error") mock_model_instance.generate.side_effect = Exception("Test error")
@ -25,10 +25,14 @@ def test_mpt7b_run():
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    output = mpt.run(
        "generate", "Once upon a time in a land far, far away..."
    )

    assert isinstance(output, str)
    assert output.startswith(
        "Once upon a time in a land far, far away..."
    )


def test_mpt7b_run_invalid_task():
@ -51,10 +55,14 @@ def test_mpt7b_generate():
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    output = mpt.generate(
        "Once upon a time in a land far, far away..."
    )

    assert isinstance(output, str)
    assert output.startswith(
        "Once upon a time in a land far, far away..."
    )


def test_mpt7b_batch_generate():

@ -43,7 +43,9 @@ def test_model_initialization(setup_nougat):
    "cuda_available, expected_device",
    [(True, "cuda"), (False, "cpu")],
)
def test_device_initialization(
    cuda_available, expected_device, monkeypatch
):
    monkeypatch.setattr(
        torch,
        "cuda",
@ -72,7 +74,9 @@ def test_get_image_invalid_path(setup_nougat):
        (10, 50),
    ],
)
def test_model_call_with_diff_params(
    setup_nougat, min_len, max_tokens
):
    setup_nougat.min_length = min_len
    setup_nougat.max_new_tokens = max_tokens
@ -118,7 +122,8 @@ def test_nougat_with_sample_image_1(setup_nougat):
        os.path.join(
            "sample_images",
            "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)
@ -135,7 +140,8 @@ def test_nougat_min_length_param(setup_nougat):
        os.path.join(
            "sample_images",
            "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)
@ -146,7 +152,8 @@ def test_nougat_max_new_tokens_param(setup_nougat):
        os.path.join(
            "sample_images",
            "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)
@ -157,13 +164,16 @@ def test_nougat_different_model_path(setup_nougat):
        os.path.join(
            "sample_images",
            "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)


@pytest.mark.usefixtures("mock_processor_and_model")
def test_nougat_bad_image_path(setup_nougat):
    with pytest.raises(
        Exception
    ):  # Adjust the exception type accordingly.
        setup_nougat("bad_image_path.png")
@ -173,7 +183,8 @@ def test_nougat_image_large_size(setup_nougat):
        os.path.join(
            "sample_images",
            "https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)
@ -183,7 +194,8 @@ def test_nougat_image_small_size(setup_nougat):
        os.path.join(
            "sample_images",
            "https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
        )
    )
    assert isinstance(result, str)
@ -193,7 +205,8 @@ def test_nougat_image_varied_content(setup_nougat):
        os.path.join(
            "sample_images",
            "https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
        )
    )
    assert isinstance(result, str)
@ -203,5 +216,6 @@ def test_nougat_image_with_metadata(setup_nougat):
        os.path.join(
            "sample_images",
            "https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
        )
    )
    assert isinstance(result, str)

@ -4,7 +4,8 @@ from swarms.models.qwen import QwenVLMultiModal

def test_post_init():
    with patch(
        "swarms.models.qwen.AutoTokenizer.from_pretrained"
    ) as mock_tokenizer, patch(
        "swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
    ) as mock_model:
@ -12,8 +13,9 @@ def test_post_init():
        mock_model.return_value = Mock()

        model = QwenVLMultiModal()
        mock_tokenizer.assert_called_once_with(
            model.model_name, trust_remote_code=True
        )
        mock_model.assert_called_once_with(
            model.model_name,
            device_map=model.device,
@ -25,29 +27,35 @@ def test_run():
    with patch(
        "swarms.models.qwen.AutoTokenizer.from_list_format"
    ) as mock_format, patch(
        "swarms.models.qwen.AutoTokenizer.__call__"
    ) as mock_call, patch(
        "swarms.models.qwen.AutoModelForCausalLM.generate"
    ) as mock_generate, patch(
        "swarms.models.qwen.AutoTokenizer.decode"
    ) as mock_decode:
        mock_format.return_value = Mock()
        mock_call.return_value = Mock()
        mock_generate.return_value = Mock()
        mock_decode.return_value = "response"

        model = QwenVLMultiModal()
        response = model.run(
            "Hello, how are you?", "https://example.com/image.jpg"
        )

        assert response == "response"


def test_chat():
    with patch(
        "swarms.models.qwen.AutoModelForCausalLM.chat"
    ) as mock_chat:
        mock_chat.return_value = ("response", ["history"])

        model = QwenVLMultiModal()
        response, history = model.chat(
            "Hello, how are you?", "https://example.com/image.jpg"
        )

        assert response == "response"
        assert history == ["history"]

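Aside, not part of this diff: the Qwen tests above stack several patch() context managers in a single with-statement so that no real model or tokenizer is ever loaded. A minimal, self-contained sketch of that pattern follows; the load_model and load_tokenizer helpers are hypothetical stand-ins, not swarms APIs.

from unittest.mock import patch


def load_model():
    # Stand-in for an expensive factory such as a from_pretrained call.
    raise RuntimeError("would download weights")


def load_tokenizer():
    # Stand-in for a second expensive factory.
    raise RuntimeError("would download a tokenizer")


def test_stacked_patches():
    # Both callables are replaced for the duration of the with-block,
    # so the test never touches the network or the filesystem.
    with patch(__name__ + ".load_model") as mock_model, patch(
        __name__ + ".load_tokenizer"
    ) as mock_tokenizer:
        mock_model.return_value = "model"
        mock_tokenizer.return_value = "tokenizer"
        assert load_model() == "model"
        assert load_tokenizer() == "tokenizer"
    mock_model.assert_called_once()
    mock_tokenizer.assert_called_once()
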
@ -16,11 +16,16 @@ def speecht5_model():

def test_speecht5_init(speecht5_model):
    assert isinstance(
        speecht5_model.processor, SpeechT5.processor.__class__
    )
    assert isinstance(speecht5_model.model, SpeechT5.model.__class__)
    assert isinstance(
        speecht5_model.vocoder, SpeechT5.vocoder.__class__
    )
    assert isinstance(
        speecht5_model.embeddings_dataset, torch.utils.data.Dataset
    )


def test_speecht5_call(speecht5_model):
@ -44,7 +49,10 @@ def test_speecht5_set_model(speecht5_model):
    speecht5_model.set_model(new_model_name)
    assert speecht5_model.model_name == new_model_name
    assert speecht5_model.processor.model_name == new_model_name
    assert (
        speecht5_model.model.config.model_name_or_path
        == new_model_name
    )
    speecht5_model.set_model(old_model_name)  # Restore original model
@ -54,8 +62,12 @@ def test_speecht5_set_vocoder(speecht5_model):
    speecht5_model.set_vocoder(new_vocoder_name)
    assert speecht5_model.vocoder_name == new_vocoder_name
    assert (
        speecht5_model.vocoder.config.model_name_or_path
        == new_vocoder_name
    )
    speecht5_model.set_vocoder(
        old_vocoder_name
    )  # Restore original vocoder


def test_speecht5_set_embeddings_dataset(speecht5_model):
@ -63,10 +75,12 @@ def test_speecht5_set_embeddings_dataset(speecht5_model):
    new_dataset_name = "Matthijs/cmu-arctic-xvectors-test"
    speecht5_model.set_embeddings_dataset(new_dataset_name)
    assert speecht5_model.dataset_name == new_dataset_name
    assert isinstance(
        speecht5_model.embeddings_dataset, torch.utils.data.Dataset
    )
    speecht5_model.set_embeddings_dataset(
        old_dataset_name
    )  # Restore original dataset


def test_speecht5_get_sampling_rate(speecht5_model):
@ -98,7 +112,9 @@ def test_speecht5_change_dataset_split(speecht5_model):
def test_speecht5_load_custom_embedding(speecht5_model):
    xvector = [0.1, 0.2, 0.3, 0.4, 0.5]
    embedding = speecht5_model.load_custom_embedding(xvector)
    assert torch.all(
        torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))
    )


def test_speecht5_with_different_speakers(speecht5_model):
@ -109,7 +125,9 @@ def test_speecht5_with_different_speakers(speecht5_model):
    assert isinstance(speech, torch.Tensor)


def test_speecht5_save_speech_with_different_extensions(
    speecht5_model,
):
    text = "Hello, how are you?"
    speech = speecht5_model(text)
    extensions = [".wav", ".flac"]
@ -144,4 +162,6 @@ def test_speecht5_change_vocoder_model(speecht5_model):
    speecht5_model.set_vocoder(new_vocoder_name)
    speech = speecht5_model(text)
    assert isinstance(speech, torch.Tensor)
    speecht5_model.set_vocoder(
        old_vocoder_name
    )  # Restore original vocoder

@ -21,30 +21,36 @@ def test_ssd1b_call(ssd1b_model):
    image_url = ssd1b_model(task, neg_prompt)
    assert isinstance(image_url, str)
    assert image_url.startswith(
        "https://"
    )  # Assuming it starts with "https://"


# Add more tests for various aspects of the class and methods


# Example of a parameterized test for different tasks
@pytest.mark.parametrize(
    "task", ["A painting of a cat", "A painting of a tree"]
)
def test_ssd1b_parameterized_task(ssd1b_model, task):
    image_url = ssd1b_model(task)
    assert isinstance(image_url, str)
    assert image_url.startswith(
        "https://"
    )  # Assuming it starts with "https://"


# Example of a test using mocks to isolate units of code
def test_ssd1b_with_mock(ssd1b_model, mocker):
    mocker.patch(
        "your_module.StableDiffusionXLPipeline"
    )  # Mock the pipeline
    task = "A painting of a cat"
    image_url = ssd1b_model(task)
    assert isinstance(image_url, str)
    assert image_url.startswith(
        "https://"
    )  # Assuming it starts with "https://"


def test_ssd1b_call_with_cache(ssd1b_model):
@ -62,7 +68,8 @@ def test_ssd1b_invalid_task(ssd1b_model):

def test_ssd1b_failed_api_call(ssd1b_model, mocker):
    mocker.patch(
        "your_module.StableDiffusionXLPipeline"
    )  # Mock the pipeline to raise an exception
    task = "A painting of a cat"
    with pytest.raises(Exception):

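Aside, not part of this diff: the SSD1B tests combine pytest.mark.parametrize with the mocker fixture from the pytest-mock plugin. A small, self-contained sketch of the same pattern, assuming pytest-mock is installed; the fetch_image_url helper is hypothetical and exists only for illustration.

import pytest


def fetch_image_url(prompt: str) -> str:
    # Hypothetical helper that would normally hit an external image API.
    raise NotImplementedError


@pytest.mark.parametrize(
    "prompt", ["A painting of a cat", "A painting of a tree"]
)
def test_fetch_image_url(prompt, mocker):
    # Replace the helper with a canned URL for every parameter set.
    mocker.patch(
        __name__ + ".fetch_image_url",
        return_value="https://example.com/image.png",
    )
    assert fetch_image_url(prompt).startswith("https://")
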
@ -19,16 +19,18 @@ def test_timm_model_init():

def test_timm_model_call():
    with patch(
        "swarms.models.timm.create_model"
    ) as mock_create_model:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        task = torch.rand(1, in_chans, 224, 224)
        result = timm_model(task)
        mock_create_model.assert_called_once_with(
            model_name, pretrained=pretrained, in_chans=in_chans
        )
        assert result == mock_create_model.return_value(task)

@ -22,14 +22,17 @@ def test_create_model(sample_model_info):

def test_call(sample_model_info):
    model_handler = TimmModel()
    input_tensor = torch.randn(1, 3, 224, 224)
    output_shape = model_handler.__call__(
        sample_model_info, input_tensor
    )
    assert isinstance(output_shape, torch.Size)


def test_get_supported_models_mock():
    model_handler = TimmModel()
    model_handler._get_supported_models = Mock(
        return_value=["resnet18", "resnet50"]
    )
    supported_models = model_handler._get_supported_models()
    assert supported_models == ["resnet18", "resnet50"]

@ -55,11 +55,7 @@ def test_init_custom_params(mock_api_key):
def test_run_success(mock_post, mock_api_key):
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response
@ -73,7 +69,8 @@ def test_run_success(mock_post, mock_api_key):

@patch("swarms.models.together_model.requests.post")
def test_run_failure(mock_post, mock_api_key):
    mock_post.side_effect = requests.exceptions.RequestException(
        "Request failed"
    )

    model = TogetherLLM()
    task = "What is the color of the object?"
@ -92,7 +89,9 @@ def test_run_with_logging_enabled(caplog, mock_api_key):
    assert "Sending request to" in caplog.text


@pytest.mark.parametrize(
    "invalid_input", [None, 123, ["list", "of", "items"]]
)
def test_invalid_task_input(invalid_input, mock_api_key):
    model = TogetherLLM()
    response = model.run(invalid_input)
@ -104,11 +103,7 @@ def test_invalid_task_input(invalid_input, mock_api_key):
def test_run_streaming_enabled(mock_post, mock_api_key):
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response

@ -20,7 +20,9 @@ def test_ultralytics_call():
    args = (1, 2, 3)
    kwargs = {"a": "A", "b": "B"}
    result = ultralytics(task, *args, **kwargs)
    mock_yolo.return_value.assert_called_once_with(
        task, *args, **kwargs
    )
    assert result == mock_yolo.return_value.return_value

@ -21,13 +21,17 @@ def test_vilt_initialization(vilt_instance):

# 2. Test Model Predictions
@patch.object(requests, "get")
@patch.object(Image, "open")
def test_vilt_prediction(
    mock_image_open, mock_requests_get, vilt_instance
):
    mock_image = Mock()
    mock_image_open.return_value = mock_image
    mock_requests_get.return_value.raw = Mock()

    # It's a mock response, so no real answer expected
    with pytest.raises(
        Exception
    ):  # Ensure exception is more specific
        vilt_instance(
            "What is this image",
            "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
@ -62,7 +66,9 @@ def test_vilt_network_exception(vilt_instance):
    ],
)
def test_vilt_various_inputs(text, image_url, vilt_instance):
    with pytest.raises(
        Exception
    ):  # Again, ensure exception is more specific
        vilt_instance(text, image_url)

@ -32,7 +32,9 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length):

@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
def test_yi34b_generate_text_with_temperature(
    yi34b_model, temperature
):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, temperature=temperature)
    assert isinstance(generated_text, str)
@ -40,24 +42,27 @@ def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):

def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
    prompt = None  # Invalid prompt
    with pytest.raises(
        ValueError, match="Input prompt must be a non-empty string"
    ):
        yi34b_model(prompt)


def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
    prompt = "There's a place where time stands still."
    max_length = -1  # Invalid max_length
    with pytest.raises(
        ValueError, match="max_length must be a positive integer"
    ):
        yi34b_model(prompt, max_length=max_length)


def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
    prompt = "There's a place where time stands still."
    temperature = 2.0  # Invalid temperature
    with pytest.raises(
        ValueError, match="temperature must be between 0.01 and 1.0"
    ):
        yi34b_model(prompt, temperature=temperature)
@ -78,27 +83,35 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):

def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
    prompt = "There's a place where time stands still."
    top_k = -1  # Invalid top_k
    with pytest.raises(
        ValueError, match="top_k must be a non-negative integer"
    ):
        yi34b_model(prompt, top_k=top_k)


def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
    prompt = "There's a place where time stands still."
    top_p = 1.5  # Invalid top_p
    with pytest.raises(
        ValueError, match="top_p must be between 0.0 and 1.0"
    ):
        yi34b_model(prompt, top_p=top_p)


@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(
    yi34b_model, repitition_penalty
):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(
        prompt, repitition_penalty=repitition_penalty
    )
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_repitition_penalty(
    yi34b_model,
):
    prompt = "There's a place where time stands still."
    repitition_penalty = 0.0  # Invalid repitition_penalty
    with pytest.raises(

@ -25,11 +25,16 @@ def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
    zeroscope = ZeroscopeTTV()
    mock_pipeline_instance = MagicMock()
    mock_pipeline.from_pretrained.return_value = (
        mock_pipeline_instance
    )
    mock_pipeline_instance.return_value = MagicMock(
        frames="Generated frames"
    )
    mock_pipeline_instance.enable_vae_slicing.assert_called_once()
    mock_pipeline_instance.enable_forward_chunking.assert_called_once_with(
        chunk_size=1, dim=1
    )
    result = zeroscope.forward("Test task")
    assert result == "Generated frames"
    mock_pipeline_instance.assert_called_once_with(
@ -46,8 +51,12 @@ def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
    zeroscope = ZeroscopeTTV()
    mock_pipeline_instance = MagicMock()
    mock_pipeline.from_pretrained.return_value = (
        mock_pipeline_instance
    )
    mock_pipeline_instance.return_value = MagicMock(
        frames="Generated frames"
    )
    mock_pipeline_instance.side_effect = Exception("Test error")
    with pytest.raises(Exception, match="Test error"):
        zeroscope.forward("Test task")
@ -58,8 +67,12 @@ def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
    zeroscope = ZeroscopeTTV()
    mock_pipeline_instance = MagicMock()
    mock_pipeline.from_pretrained.return_value = (
        mock_pipeline_instance
    )
    mock_pipeline_instance.return_value = MagicMock(
        frames="Generated frames"
    )
    result = zeroscope.__call__("Test task")
    assert result == "Generated frames"
    mock_pipeline_instance.assert_called_once_with(
@ -76,8 +89,12 @@ def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
    zeroscope = ZeroscopeTTV()
    mock_pipeline_instance = MagicMock()
    mock_pipeline.from_pretrained.return_value = (
        mock_pipeline_instance
    )
    mock_pipeline_instance.return_value = MagicMock(
        frames="Generated frames"
    )
    mock_pipeline_instance.side_effect = Exception("Test error")
    with pytest.raises(Exception, match="Test error"):
        zeroscope.__call__("Test task")
@ -88,8 +105,12 @@ def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
    zeroscope = ZeroscopeTTV()
    mock_pipeline_instance = MagicMock()
    mock_pipeline.from_pretrained.return_value = (
        mock_pipeline_instance
    )
    mock_pipeline_instance.return_value = MagicMock(
        frames="Generated frames"
    )
    result = zeroscope.save_video_path("Test video path")
    assert result == "Test video path"
    mock_pipeline_instance.assert_called_once_with(

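Aside, not part of this diff: the Zeroscope tests above all follow one shape, patch a from_pretrained factory, hand back a MagicMock, then assert on how the code under test drove that instance. A compact, self-contained sketch of that shape; the Pipeline class and render function below are hypothetical stand-ins, not the swarms implementation.

from unittest.mock import MagicMock, patch


class Pipeline:
    # Hypothetical stand-in for a diffusers-style pipeline class.
    @classmethod
    def from_pretrained(cls, checkpoint: str):
        raise RuntimeError("would download weights")


def render(task: str) -> str:
    # Hypothetical code under test: build the pipeline and return its frames.
    pipe = Pipeline.from_pretrained("some/checkpoint")
    return pipe(task).frames


def test_render():
    with patch(__name__ + ".Pipeline") as mock_pipeline:
        instance = MagicMock()
        instance.return_value = MagicMock(frames="Generated frames")
        mock_pipeline.from_pretrained.return_value = instance

        assert render("Test task") == "Generated frames"
        instance.assert_called_once_with("Test task")
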
@ -18,7 +18,9 @@ openai_api_key = os.getenv("OPENAI_API_KEY")

# Mocks and Fixtures
@pytest.fixture
def mocked_llm():
    return OpenAIChat(
        openai_api_key=openai_api_key,
    )


@pytest.fixture
@ -63,12 +65,15 @@ def test_provide_feedback(basic_flow):
@patch("time.sleep", return_value=None)  # to speed up tests
def test_run_without_stopping_condition(mocked_sleep, basic_flow):
    response = basic_flow.run("Test task")
    assert (
        response == "Test task"
    )  # since our mocked llm doesn't modify the response


@patch("time.sleep", return_value=None)  # to speed up tests
def test_run_with_stopping_condition(
    mocked_sleep, flow_with_condition
):
    response = flow_with_condition.run("Stop")
    assert response == "Stop"
@ -108,7 +113,6 @@ def test_env_variable_handling(monkeypatch):
# Test initializing the agent with different stopping conditions
def test_flow_with_custom_stopping_condition(mocked_llm):
    def stopping_condition(x):
        return "terminate" in x.lower()
@ -129,7 +133,9 @@ def test_flow_call(basic_flow):
# Test formatting the prompt
def test_format_prompt(basic_flow):
    formatted_prompt = basic_flow.format_prompt(
        "Hello {name}", name="John"
    )
    assert formatted_prompt == "Hello John"
@ -158,15 +164,9 @@ def test_interactive_mode(basic_flow):
# Test bulk run with varied inputs
def test_bulk_run_varied_inputs(basic_flow):
    inputs = [
        {"task": "Test1"},
        {"task": "Test2"},
        {"task": "Stop now"},
    ]
    responses = basic_flow.bulk_run(inputs)
    assert responses == ["Test1", "Test2", "Stop now"]
@ -191,9 +191,12 @@ def test_save_different_memory(basic_flow, tmp_path):
# Test the stopping condition check
def test_check_stopping_condition(flow_with_condition):
    assert flow_with_condition._check_stopping_condition(
        "Stop this process"
    )
    assert not flow_with_condition._check_stopping_condition(
        "Continue the task"
    )


# Test without providing max loops (default value should be 5)
@ -249,7 +252,9 @@ def test_different_retry_intervals(mocked_sleep, basic_flow):
# Test invoking the agent with additional kwargs
@patch("time.sleep", return_value=None)
def test_flow_call_with_kwargs(mocked_sleep, basic_flow):
    response = basic_flow(
        "Test call", param1="value1", param2="value2"
    )
    assert response == "Test call"
@ -284,7 +289,9 @@ def test_stopping_token_in_response(mocked_sleep, basic_flow):
def flow_instance():
    # Create an instance of the Agent class with required parameters for testing
    # You may need to adjust this based on your actual class initialization
    llm = OpenAIChat(
        openai_api_key=openai_api_key,
    )
    agent = Agent(
        llm=llm,
        max_loops=5,
@ -331,7 +338,9 @@ def test_flow_autosave(flow_instance):
def test_flow_response_filtering(flow_instance):
    # Test the response filtering functionality
    flow_instance.add_response_filter("filter_this")
    response = flow_instance.filtered_run(
        "This message should filter_this"
    )
    assert "filter_this" not in response
@ -391,8 +400,11 @@ def test_flow_response_length(flow_instance):
    # Test checking the length of the response
    response = flow_instance.run(
        "Generate a 10,000 word long blog on mental clarity and the"
        " benefits of meditation."
    )
    assert (
        len(response) > flow_instance.get_response_length_threshold()
    )


def test_flow_set_response_length_threshold(flow_instance):
@ -481,7 +493,9 @@ def test_flow_get_conversation_log(flow_instance):
    flow_instance.run("Message 1")
    flow_instance.run("Message 2")
    conversation_log = flow_instance.get_conversation_log()
    assert (
        len(conversation_log) == 4
    )  # Including system and user messages


def test_flow_clear_conversation_log(flow_instance):
@ -565,18 +579,37 @@ def test_flow_rollback(flow_instance):
    flow_instance.change_prompt("New prompt")
    flow_instance.get_state()
    flow_instance.rollback_to_state(state1)
    assert (
        flow_instance.get_current_prompt() == state1["current_prompt"]
    )
    assert flow_instance.get_instructions() == state1["instructions"]
    assert (
        flow_instance.get_user_messages() == state1["user_messages"]
    )
    assert (
        flow_instance.get_response_history()
        == state1["response_history"]
    )
    assert (
        flow_instance.get_conversation_log()
        == state1["conversation_log"]
    )
    assert (
        flow_instance.is_dynamic_pacing_enabled()
        == state1["dynamic_pacing_enabled"]
    )
    assert (
        flow_instance.get_response_length_threshold()
        == state1["response_length_threshold"]
    )
    assert (
        flow_instance.get_response_filters()
        == state1["response_filters"]
    )
    assert flow_instance.get_max_loops() == state1["max_loops"]
    assert (
        flow_instance.get_autosave_path() == state1["autosave_path"]
    )
    assert flow_instance.get_state() == state1
@ -585,7 +618,8 @@ def test_flow_contextual_intent(flow_instance):
    flow_instance.add_context("location", "New York")
    flow_instance.add_context("time", "tomorrow")
    response = flow_instance.run(
        "What's the weather like in {location} at {time}?"
    )
    assert "New York" in response
    assert "tomorrow" in response
@ -593,9 +627,13 @@ def test_flow_contextual_intent(flow_instance):
def test_flow_contextual_intent_override(flow_instance):
    # Test contextual intent override
    flow_instance.add_context("location", "New York")
    response1 = flow_instance.run(
        "What's the weather like in {location}?"
    )
    flow_instance.add_context("location", "Los Angeles")
    response2 = flow_instance.run(
        "What's the weather like in {location}?"
    )
    assert "New York" in response1
    assert "Los Angeles" in response2
@ -603,9 +641,13 @@ def test_flow_contextual_intent_override(flow_instance):
def test_flow_contextual_intent_reset(flow_instance):
    # Test resetting contextual intent
    flow_instance.add_context("location", "New York")
    response1 = flow_instance.run(
        "What's the weather like in {location}?"
    )
    flow_instance.reset_context()
    response2 = flow_instance.run(
        "What's the weather like in {location}?"
    )
    assert "New York" in response1
    assert "New York" in response2
@ -630,7 +672,9 @@ def test_flow_non_interruptible(flow_instance):
def test_flow_timeout(flow_instance):
    # Test conversation timeout
    flow_instance.timeout = 60  # Set a timeout of 60 seconds
    response = flow_instance.run(
        "This should take some time to respond."
    )
    assert "Timed out" in response
    assert flow_instance.is_timed_out() is True
@ -679,14 +723,20 @@ def test_flow_save_and_load_conversation(flow_instance):
def test_flow_inject_custom_system_message(flow_instance):
    # Test injecting a custom system message into the conversation
    flow_instance.inject_custom_system_message(
        "Custom system message"
    )
    assert (
        "Custom system message" in flow_instance.get_message_history()
    )


def test_flow_inject_custom_user_message(flow_instance):
    # Test injecting a custom user message into the conversation
    flow_instance.inject_custom_user_message("Custom user message")
    assert (
        "Custom user message" in flow_instance.get_message_history()
    )


def test_flow_inject_custom_response(flow_instance):
@ -697,28 +747,45 @@ def test_flow_inject_custom_response(flow_instance):
def test_flow_clear_injected_messages(flow_instance):
    # Test clearing injected messages from the conversation
    flow_instance.inject_custom_system_message(
        "Custom system message"
    )
    flow_instance.inject_custom_user_message("Custom user message")
    flow_instance.inject_custom_response("Custom response")
    flow_instance.clear_injected_messages()
    assert (
        "Custom system message"
        not in flow_instance.get_message_history()
    )
    assert (
        "Custom user message"
        not in flow_instance.get_message_history()
    )
    assert (
        "Custom response" not in flow_instance.get_message_history()
    )


def test_flow_disable_message_history(flow_instance):
    # Test disabling message history recording
    flow_instance.disable_message_history()
    response = flow_instance.run(
        "This message should not be recorded in history."
    )
    assert (
        "This message should not be recorded in history." in response
    )
    assert (
        len(flow_instance.get_message_history()) == 0
    )  # History is empty


def test_flow_enable_message_history(flow_instance):
    # Test enabling message history recording
    flow_instance.enable_message_history()
    response = flow_instance.run(
        "This message should be recorded in history."
    )
    assert "This message should be recorded in history." in response
    assert len(flow_instance.get_message_history()) == 1
@ -728,7 +795,9 @@ def test_flow_custom_logger(flow_instance):
    custom_logger = logger  # Replace with your custom logger class
    flow_instance.set_logger(custom_logger)
    response = flow_instance.run("Custom logger test")
    assert (
        "Logged using custom logger" in response
    )  # Verify logging message


def test_flow_batch_processing(flow_instance):
@ -802,35 +871,43 @@ def test_flow_input_validation(flow_instance):
    with pytest.raises(ValueError):
        flow_instance.set_message_delimiter(
            ""
        )  # Empty delimiter, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.set_message_delimiter(
            None
        )  # None delimiter, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.set_message_delimiter(
            123
        )  # Invalid delimiter type, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.set_logger(
            "invalid_logger"
        )  # Invalid logger type, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.add_context(
            None, "value"
        )  # None key, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.add_context(
            "key", None
        )  # None value, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.update_context(
            None, "value"
        )  # None key, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.update_context(
            "key", None
        )  # None value, should raise ValueError


def test_flow_conversation_reset(flow_instance):
@ -857,7 +934,6 @@ def test_flow_conversation_persistence(flow_instance):
def test_flow_custom_event_listener(flow_instance):
    # Test custom event listener
    class CustomEventListener:
        def on_message_received(self, message):
            pass
@ -869,10 +945,10 @@ def test_flow_custom_event_listener(flow_instance):
    # Ensure that the custom event listener methods are called during a conversation
    with mock.patch.object(
        custom_event_listener, "on_message_received"
    ) as mock_received, mock.patch.object(
        custom_event_listener, "on_response_generated"
    ) as mock_response:
        flow_instance.run("Message 1")
        mock_received.assert_called_once()
        mock_response.assert_called_once()
@ -881,7 +957,6 @@ def test_flow_custom_event_listener(flow_instance):
def test_flow_multiple_event_listeners(flow_instance):
    # Test multiple event listeners
    class FirstEventListener:
        def on_message_received(self, message):
            pass
@ -889,7 +964,6 @@ def test_flow_multiple_event_listeners(flow_instance):
            pass

    class SecondEventListener:
        def on_message_received(self, message):
            pass
@ -903,14 +977,14 @@ def test_flow_multiple_event_listeners(flow_instance):
    # Ensure that both event listeners receive events during a conversation
    with mock.patch.object(
        first_event_listener, "on_message_received"
    ) as mock_first_received, mock.patch.object(
        first_event_listener, "on_response_generated"
    ) as mock_first_response, mock.patch.object(
        second_event_listener, "on_message_received"
    ) as mock_second_received, mock.patch.object(
        second_event_listener, "on_response_generated"
    ) as mock_second_response:
        flow_instance.run("Message 1")
        mock_first_received.assert_called_once()
        mock_first_response.assert_called_once()
@ -923,31 +997,38 @@ def test_flow_error_handling(flow_instance):
    # Test error handling and exceptions
    with pytest.raises(ValueError):
        flow_instance.set_message_delimiter(
            ""
        )  # Empty delimiter, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.set_message_delimiter(
            None
        )  # None delimiter, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.set_logger(
            "invalid_logger"
        )  # Invalid logger type, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.add_context(
            None, "value"
        )  # None key, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.add_context(
            "key", None
        )  # None value, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.update_context(
            None, "value"
        )  # None key, should raise ValueError
    with pytest.raises(ValueError):
        flow_instance.update_context(
            "key", None
        )  # None value, should raise ValueError


def test_flow_context_operations(flow_instance):
@ -984,8 +1065,14 @@ def test_flow_custom_response(flow_instance):
    flow_instance.set_response_generator(custom_response_generator)
    assert flow_instance.run("Hello") == "Hi there!"
    assert (
        flow_instance.run("How are you?")
        == "I'm doing well, thank you."
    )
    assert (
        flow_instance.run("What's your name?")
        == "I don't understand."
    )
def test_flow_message_validation(flow_instance): def test_flow_message_validation(flow_instance):
@ -996,8 +1083,12 @@ def test_flow_message_validation(flow_instance):
flow_instance.set_message_validator(custom_message_validator) flow_instance.set_message_validator(custom_message_validator)
assert flow_instance.run("Valid message") is not None assert flow_instance.run("Valid message") is not None
assert (flow_instance.run("") is None) # Empty message should be rejected assert (
assert (flow_instance.run(None) is None) # None message should be rejected flow_instance.run("") is None
) # Empty message should be rejected
assert (
flow_instance.run(None) is None
) # None message should be rejected
def test_flow_custom_logging(flow_instance): def test_flow_custom_logging(flow_instance):
@ -1022,10 +1113,15 @@ def test_flow_complex_use_case(flow_instance):
flow_instance.add_context("user_id", "12345") flow_instance.add_context("user_id", "12345")
flow_instance.run("Hello") flow_instance.run("Hello")
flow_instance.run("How can I help you?") flow_instance.run("How can I help you?")
assert (flow_instance.get_response() == "Please provide more details.") assert (
flow_instance.get_response() == "Please provide more details."
)
flow_instance.update_context("user_id", "54321") flow_instance.update_context("user_id", "54321")
flow_instance.run("I need help with my order") flow_instance.run("I need help with my order")
assert (flow_instance.get_response() == "Sure, I can assist with that.") assert (
flow_instance.get_response()
== "Sure, I can assist with that."
)
flow_instance.reset_conversation() flow_instance.reset_conversation()
assert len(flow_instance.get_message_history()) == 0 assert len(flow_instance.get_message_history()) == 0
assert flow_instance.get_context("user_id") is None assert flow_instance.get_context("user_id") is None
@ -1064,7 +1160,9 @@ def test_flow_concurrent_requests(flow_instance):
def test_flow_custom_timeout(flow_instance): def test_flow_custom_timeout(flow_instance):
# Test custom timeout handling # Test custom timeout handling
flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds flow_instance.set_timeout(
10
) # Set a custom timeout of 10 seconds
assert flow_instance.get_timeout() == 10 assert flow_instance.get_timeout() == 10
import time import time
@ -1115,10 +1213,16 @@ def test_flow_agent_history_prompt(flow_instance):
history = ["User: Hi", "AI: Hello"] history = ["User: Hi", "AI: Hello"]
agent_history_prompt = flow_instance.agent_history_prompt( agent_history_prompt = flow_instance.agent_history_prompt(
        system_prompt, history
    )
    assert (
        "SYSTEM_PROMPT: This is the system prompt."
        in agent_history_prompt
    )
    assert (
        "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt
    )
async def test_flow_run_concurrent(flow_instance): async def test_flow_run_concurrent(flow_instance):
@ -1133,18 +1237,9 @@ async def test_flow_run_concurrent(flow_instance):
def test_flow_bulk_run(flow_instance): def test_flow_bulk_run(flow_instance):
# Test bulk running of tasks # Test bulk running of tasks
    input_data = [
        {"task": "Task 1", "param1": "value1"},
        {"task": "Task 2", "param2": "value2"},
        {"task": "Task 3", "param3": "value3"},
    ]
responses = flow_instance.bulk_run(input_data) responses = flow_instance.bulk_run(input_data)
@ -1159,7 +1254,9 @@ def test_flow_from_llm_and_template():
llm_instance = mocked_llm # Replace with your LLM class llm_instance = mocked_llm # Replace with your LLM class
template = "This is a template for testing." template = "This is a template for testing."
    flow_instance = Agent.from_llm_and_template(
        llm_instance, template
    )
assert isinstance(flow_instance, Agent) assert isinstance(flow_instance, Agent)
@ -1168,10 +1265,12 @@ def test_flow_from_llm_and_template_file():
# Test creating Agent instance from an LLM and a template file # Test creating Agent instance from an LLM and a template file
llm_instance = mocked_llm # Replace with your LLM class llm_instance = mocked_llm # Replace with your LLM class
    template_file = (  # Create a template file for testing
        "template.txt"
    )
    flow_instance = Agent.from_llm_and_template_file(
        llm_instance, template_file
    )
assert isinstance(flow_instance, Agent) assert isinstance(flow_instance, Agent)

@ -44,7 +44,9 @@ def test_autoscaler_run():
agent.id, agent.id,
"Generate a 10,000 word blog on health and wellness.", "Generate a 10,000 word blog on health and wellness.",
) )
assert (out == "Generate a 10,000 word blog on health and wellness.") assert (
out == "Generate a 10,000 word blog on health and wellness."
)
def test_autoscaler_add_agent(): def test_autoscaler_add_agent():
@ -237,7 +239,9 @@ def test_autoscaler_add_task():
def test_autoscaler_scale_up(): def test_autoscaler_scale_up():
    autoscaler = AutoScaler(
        initial_agents=5, scale_up_factor=2, agent=agent
    )
autoscaler.scale_up() autoscaler.scale_up()
assert len(autoscaler.agents_pool) == 10 assert len(autoscaler.agents_pool) == 10

@ -7,7 +7,6 @@ from swarms.structs.base import BaseStructure
class TestBaseStructure: class TestBaseStructure:
def test_init(self): def test_init(self):
base_structure = BaseStructure( base_structure = BaseStructure(
name="TestStructure", name="TestStructure",
@ -89,8 +88,11 @@ class TestBaseStructure:
with open(log_file) as file: with open(log_file) as file:
lines = file.readlines() lines = file.readlines()
assert len(lines) == 1 assert len(lines) == 1
            assert (
                lines[0]
                == f"[{base_structure._current_timestamp()}]"
                f" [{event_type}] {event}\n"
            )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async(self): async def test_run_async(self):
@ -134,7 +136,9 @@ class TestBaseStructure:
artifact = {"key": "value"} artifact = {"key": "value"}
artifact_name = "test_artifact" artifact_name = "test_artifact"
        await base_structure.save_artifact_async(
            artifact, artifact_name
        )
loaded_artifact = base_structure.load_artifact(artifact_name) loaded_artifact = base_structure.load_artifact(artifact_name)
assert loaded_artifact == artifact assert loaded_artifact == artifact
@ -147,7 +151,8 @@ class TestBaseStructure:
artifact = {"key": "value"} artifact = {"key": "value"}
artifact_name = "test_artifact" artifact_name = "test_artifact"
base_structure.save_artifact(artifact, artifact_name) base_structure.save_artifact(artifact, artifact_name)
        loaded_artifact = await base_structure.load_artifact_async(
            artifact_name
        )
assert loaded_artifact == artifact assert loaded_artifact == artifact
@ -165,8 +170,11 @@ class TestBaseStructure:
with open(log_file) as file: with open(log_file) as file:
lines = file.readlines() lines = file.readlines()
assert len(lines) == 1 assert len(lines) == 1
            assert (
                lines[0]
                == f"[{base_structure._current_timestamp()}]"
                f" [{event_type}] {event}\n"
            )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_asave_to_file(self, tmpdir): async def test_asave_to_file(self, tmpdir):
@ -193,14 +201,18 @@ class TestBaseStructure:
def test_run_in_thread(self): def test_run_in_thread(self):
base_structure = BaseStructure() base_structure = BaseStructure()
        result = base_structure.run_in_thread(
            lambda: "Thread Test Result"
        )
assert result.result() == "Thread Test Result" assert result.result() == "Thread Test Result"
def test_save_and_decompress_data(self): def test_save_and_decompress_data(self):
base_structure = BaseStructure() base_structure = BaseStructure()
data = {"key": "value"} data = {"key": "value"}
compressed_data = base_structure.compress_data(data) compressed_data = base_structure.compress_data(data)
        decompressed_data = base_structure.decompres_data(
            compressed_data
        )
assert decompressed_data == data assert decompressed_data == data
def test_run_batched(self): def test_run_batched(self):
@ -210,11 +222,13 @@ class TestBaseStructure:
return f"Processed {data}" return f"Processed {data}"
batched_data = list(range(10)) batched_data = list(range(10))
        result = base_structure.run_batched(
            batched_data, batch_size=5, func=run_function
        )
        expected_result = [
            f"Processed {data}" for data in batched_data
        ]
assert result == expected_result assert result == expected_result
def test_load_config(self, tmpdir): def test_load_config(self, tmpdir):
@ -232,12 +246,15 @@ class TestBaseStructure:
tmp_dir = tmpdir.mkdir("test_dir") tmp_dir = tmpdir.mkdir("test_dir")
base_structure = BaseStructure() base_structure = BaseStructure()
data_to_backup = {"key": "value"} data_to_backup = {"key": "value"}
        base_structure.backup_data(
            data_to_backup, backup_path=tmp_dir
        )
        backup_files = os.listdir(tmp_dir)
        assert len(backup_files) == 1
        loaded_data = base_structure.load_from_file(
            os.path.join(tmp_dir, backup_files[0])
        )
assert loaded_data == data_to_backup assert loaded_data == data_to_backup
def test_monitor_resources(self): def test_monitor_resources(self):
@ -262,9 +279,11 @@ class TestBaseStructure:
return f"Processed {data}" return f"Processed {data}"
batched_data = list(range(10)) batched_data = list(range(10))
        result = base_structure.run_with_resources_batched(
            batched_data, batch_size=5, func=run_function
        )
        expected_result = [
            f"Processed {data}" for data in batched_data
        ]
assert result == expected_result assert result == expected_result

@ -30,8 +30,13 @@ def test_load_workflow_state():
workflow.load_workflow_state("workflow_state.json") workflow.load_workflow_state("workflow_state.json")
assert workflow.max_loops == 1 assert workflow.max_loops == 1
assert len(workflow.tasks) == 2 assert len(workflow.tasks) == 2
    assert (
        workflow.tasks[0].description == "What's the weather in miami"
    )
    assert (
        workflow.tasks[1].description
        == "Create a report on these metrics"
    )
teardown_workflow() teardown_workflow()

@ -18,7 +18,9 @@ def test_run():
workflow.add(task1) workflow.add(task1)
workflow.add(task2) workflow.add(task2)
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor: with patch(
"concurrent.futures.ThreadPoolExecutor"
) as mock_executor:
future1 = Future() future1 = Future()
future1.set_result(None) future1.set_result(None)
future2 = Future() future2 = Future()
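
For reference, a minimal stdlib-only sketch of the mocking pattern this hunk relies on: pre-resolving concurrent.futures.Future objects and patching ThreadPoolExecutor so no real threads ever run. The submit() wiring and the test name are illustrative assumptions, not part of the swarms code.

from concurrent.futures import Future
from unittest.mock import patch

def test_executor_is_never_used_for_real_work():
    future = Future()
    future.set_result(None)  # resolved up front, so nothing ever blocks
    with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor:
        executor = mock_executor.return_value.__enter__.return_value
        executor.submit.return_value = future  # every submit() yields the canned future
        assert executor.submit(print, "hi").result() is None
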

@ -87,13 +87,16 @@ def test_return_history_as_string_with_different_roles(role, content):
@pytest.mark.parametrize("message_count", range(1, 11)) @pytest.mark.parametrize("message_count", range(1, 11))
def test_return_history_as_string_with_multiple_messages(
    message_count,
):
    conv = Conversation()
    for i in range(message_count):
        conv.add("user", f"Message {i + 1}")
    result = conv.return_history_as_string()
    expected = "".join(
        [f"user: Message {i + 1}\n\n" for i in range(message_count)]
    )
assert result == expected assert result == expected
@ -119,8 +122,10 @@ def test_return_history_as_string_with_large_message(conversation):
large_message = "Hello, world! " * 10000 # 10,000 repetitions large_message = "Hello, world! " * 10000 # 10,000 repetitions
conversation.add("user", large_message) conversation.add("user", large_message)
result = conversation.return_history_as_string() result = conversation.return_history_as_string()
expected = ("user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" expected = (
f" {large_message}\n\n") "user: Hello, world!\n\nassistant: Hello, user!\n\nuser:"
f" {large_message}\n\n"
)
assert result == expected assert result == expected
@ -136,8 +141,10 @@ def test_export_import_conversation(conversation, tmp_path):
conversation.export_conversation(filename) conversation.export_conversation(filename)
new_conversation = Conversation() new_conversation = Conversation()
new_conversation.import_conversation(filename) new_conversation.import_conversation(filename)
    assert (
        new_conversation.return_history_as_string()
        == conversation.return_history_as_string()
    )
def test_count_messages_by_role(conversation): def test_count_messages_by_role(conversation):

@ -11,7 +11,6 @@ llm2 = Anthropic()
# Mock the OpenAI class for testing # Mock the OpenAI class for testing
class MockOpenAI: class MockOpenAI:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@ -140,9 +139,9 @@ def test_groupchat_manager_generate_reply():
selector = agent1 selector = agent1
# Initialize GroupChatManager # Initialize GroupChatManager
    manager = GroupChatManager(
        groupchat=groupchat, selector=selector, openai=mocked_openai
    )
# Generate a reply # Generate a reply
task = "Write me a riddle" task = "Write me a riddle"
@ -166,8 +165,9 @@ def test_groupchat_select_speaker():
# Simulate selecting the next speaker # Simulate selecting the next speaker
last_speaker = agent1 last_speaker = agent1
    next_speaker = manager.select_speaker(
        last_speaker=last_speaker, selector=selector
    )
# Ensure the next speaker is agent2 # Ensure the next speaker is agent2
assert next_speaker == agent2 assert next_speaker == agent2
@ -185,8 +185,9 @@ def test_groupchat_underpopulated_group():
# Simulate selecting the next speaker in an underpopulated group # Simulate selecting the next speaker in an underpopulated group
last_speaker = agent1 last_speaker = agent1
    next_speaker = manager.select_speaker(
        last_speaker=last_speaker, selector=selector
    )
# Ensure the next speaker is the same as the last speaker in an underpopulated group # Ensure the next speaker is the same as the last speaker in an underpopulated group
assert next_speaker == last_speaker assert next_speaker == last_speaker
@ -204,13 +205,15 @@ def test_groupchat_max_rounds():
# Simulate the conversation with max rounds # Simulate the conversation with max rounds
last_speaker = agent1 last_speaker = agent1
for _ in range(2): for _ in range(2):
        next_speaker = manager.select_speaker(
            last_speaker=last_speaker, selector=selector
        )
        last_speaker = next_speaker
    # Try one more round, should stay with the last speaker
    next_speaker = manager.select_speaker(
        last_speaker=last_speaker, selector=selector
    )
# Ensure the next speaker is the same as the last speaker after reaching max rounds # Ensure the next speaker is the same as the last speaker after reaching max rounds
assert next_speaker == last_speaker assert next_speaker == last_speaker

@ -15,8 +15,10 @@ def valid_schema_path(tmp_path):
d = tmp_path / "sub" d = tmp_path / "sub"
d.mkdir() d.mkdir()
p = d / "schema.json" p = d / "schema.json"
p.write_text('{"type": "object", "properties": {"name": {"type":' p.write_text(
' "string"}}}') '{"type": "object", "properties": {"name": {"type":'
' "string"}}}'
)
return str(p) return str(p)
@ -31,7 +33,6 @@ def invalid_schema_path(tmp_path):
# This test class must be subclassed as JSON class is abstract # This test class must be subclassed as JSON class is abstract
class TestableJSON(JSON): class TestableJSON(JSON):
def validate(self, data): def validate(self, data):
# Here must be a real validation implementation for testing # Here must be a real validation implementation for testing
pass pass

@ -35,9 +35,15 @@ def test_majority_voting_run_concurrent(mocker):
majority_vote = mv.run("What is the capital of France?") majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task # Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?") agent1.run.assert_called_once_with(
agent2.run.assert_called_once_with("What is the capital of France?") "What is the capital of France?"
agent3.run.assert_called_once_with("What is the capital of France?") )
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses # Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0]) conversation.add.assert_any_call(agent1.agent_name, results[0])
@ -77,9 +83,15 @@ def test_majority_voting_run_multithreaded(mocker):
majority_vote = mv.run("What is the capital of France?") majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task # Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?") agent1.run.assert_called_once_with(
agent2.run.assert_called_once_with("What is the capital of France?") "What is the capital of France?"
agent3.run.assert_called_once_with("What is the capital of France?") )
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses # Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0]) conversation.add.assert_any_call(agent1.agent_name, results[0])
@ -121,9 +133,15 @@ async def test_majority_voting_run_asynchronous(mocker):
majority_vote = await mv.run("What is the capital of France?") majority_vote = await mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task # Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?") agent1.run.assert_called_once_with(
agent2.run.assert_called_once_with("What is the capital of France?") "What is the capital of France?"
agent3.run.assert_called_once_with("What is the capital of France?") )
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses # Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0]) conversation.add.assert_any_call(agent1.agent_name, results[0])

@ -8,7 +8,9 @@ def test_message_pool_initialization():
agent2 = Agent(llm=OpenAIChat(), agent_name="agent1") agent2 = Agent(llm=OpenAIChat(), agent_name="agent1")
moderator = Agent(llm=OpenAIChat(), agent_name="agent1") moderator = Agent(llm=OpenAIChat(), agent_name="agent1")
agents = [agent1, agent2] agents = [agent1, agent2]
    message_pool = MessagePool(
        agents=agents, moderator=moderator, turns=5
    )
assert message_pool.agent == agents assert message_pool.agent == agents
assert message_pool.moderator == moderator assert message_pool.moderator == moderator
@ -18,21 +20,27 @@ def test_message_pool_initialization():
def test_message_pool_add(): def test_message_pool_add():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    message_pool = MessagePool(
        agents=[agent1], moderator=agent1, turns=5
    )
    message_pool.add(agent=agent1, content="Hello, world!", turn=1)
    assert message_pool.messages == [
        {
            "agent": agent1,
            "content": "Hello, world!",
            "turn": 1,
            "visible_to": "all",
            "logged": True,
        }
    ]
def test_message_pool_reset(): def test_message_pool_reset():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    message_pool = MessagePool(
        agents=[agent1], moderator=agent1, turns=5
    )
message_pool.add(agent=agent1, content="Hello, world!", turn=1) message_pool.add(agent=agent1, content="Hello, world!", turn=1)
message_pool.reset() message_pool.reset()
@ -41,7 +49,9 @@ def test_message_pool_reset():
def test_message_pool_last_turn(): def test_message_pool_last_turn():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    message_pool = MessagePool(
        agents=[agent1], moderator=agent1, turns=5
    )
message_pool.add(agent=agent1, content="Hello, world!", turn=1) message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.last_turn() == 1 assert message_pool.last_turn() == 1
@ -49,7 +59,9 @@ def test_message_pool_last_turn():
def test_message_pool_last_message(): def test_message_pool_last_message():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    message_pool = MessagePool(
        agents=[agent1], moderator=agent1, turns=5
    )
message_pool.add(agent=agent1, content="Hello, world!", turn=1) message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.last_message == { assert message_pool.last_message == {
@ -63,24 +75,28 @@ def test_message_pool_last_message():
def test_message_pool_get_all_messages(): def test_message_pool_get_all_messages():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    message_pool = MessagePool(
        agents=[agent1], moderator=agent1, turns=5
    )
    message_pool.add(agent=agent1, content="Hello, world!", turn=1)
    assert message_pool.get_all_messages() == [
        {
            "agent": agent1,
            "content": "Hello, world!",
            "turn": 1,
            "visible_to": "all",
            "logged": True,
        }
    ]
def test_message_pool_get_visible_messages(): def test_message_pool_get_visible_messages():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1") agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
agent2 = Agent(agent_name="agent2") agent2 = Agent(agent_name="agent2")
    message_pool = MessagePool(
        agents=[agent1, agent2], moderator=agent1, turns=5
    )
message_pool.add( message_pool.add(
agent=agent1, agent=agent1,
content="Hello, agent2!", content="Hello, agent2!",
@ -88,10 +104,14 @@ def test_message_pool_get_visible_messages():
visible_to=[agent2.agent_name], visible_to=[agent2.agent_name],
) )
    assert message_pool.get_visible_messages(
        agent=agent2, turn=2
    ) == [
        {
            "agent": agent1,
            "content": "Hello, agent2!",
            "turn": 1,
            "visible_to": [agent2.agent_name],
            "logged": True,
        }
    ]

@ -11,9 +11,7 @@ from swarms.structs.model_parallizer import ModelParallelizer
# Initialize the models # Initialize the models
custom_config = { custom_config = {
"quantize": True, "quantize": True,
"quantization_config": { "quantization_config": {"load_in_4bit": True},
"load_in_4bit": True
},
"verbose": True, "verbose": True,
} }
huggingface_llm = HuggingfaceLLM( huggingface_llm = HuggingfaceLLM(
@ -26,12 +24,14 @@ zeroscope_ttv = ZeroscopeTTV()
def test_init(): def test_init():
    mp = ModelParallelizer(
        [
            huggingface_llm,
            mixtral,
            gpt4_vision_api,
            zeroscope_ttv,
        ]
    )
assert isinstance(mp, ModelParallelizer) assert isinstance(mp, ModelParallelizer)
@ -39,20 +39,24 @@ def test_run():
mp = ModelParallelizer([huggingface_llm]) mp = ModelParallelizer([huggingface_llm])
result = mp.run( result = mp.run(
"Create a list of known biggest risks of structural collapse" "Create a list of known biggest risks of structural collapse"
" with references") " with references"
)
assert isinstance(result, str) assert isinstance(result, str)
def test_run_all(): def test_run_all():
    mp = ModelParallelizer(
        [
            huggingface_llm,
            mixtral,
            gpt4_vision_api,
            zeroscope_ttv,
        ]
    )
    result = mp.run_all(
        "Create a list of known biggest risks of structural collapse"
        " with references"
    )
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 5 assert len(result) == 5
@ -71,8 +75,10 @@ def test_remove_llm():
def test_save_responses_to_file(tmp_path): def test_save_responses_to_file(tmp_path):
mp = ModelParallelizer([huggingface_llm]) mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse" mp.run(
" with references") "Create a list of known biggest risks of structural collapse"
" with references"
)
file = tmp_path / "responses.txt" file = tmp_path / "responses.txt"
mp.save_responses_to_file(file) mp.save_responses_to_file(file)
assert file.read_text() != "" assert file.read_text() != ""
@ -80,8 +86,10 @@ def test_save_responses_to_file(tmp_path):
def test_get_task_history(): def test_get_task_history():
mp = ModelParallelizer([huggingface_llm]) mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse" mp.run(
" with references") "Create a list of known biggest risks of structural collapse"
" with references"
)
assert mp.get_task_history() == [ assert mp.get_task_history() == [
"Create a list of known biggest risks of structural collapse" "Create a list of known biggest risks of structural collapse"
" with references" " with references"
@ -90,8 +98,10 @@ def test_get_task_history():
def test_summary(capsys): def test_summary(capsys):
mp = ModelParallelizer([huggingface_llm]) mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse" mp.run(
" with references") "Create a list of known biggest risks of structural collapse"
" with references"
)
mp.summary() mp.summary()
captured = capsys.readouterr() captured = capsys.readouterr()
assert "Tasks History:" in captured.out assert "Tasks History:" in captured.out
@ -113,7 +123,8 @@ def test_concurrent_run():
mp = ModelParallelizer([huggingface_llm, mixtral]) mp = ModelParallelizer([huggingface_llm, mixtral])
result = mp.concurrent_run( result = mp.concurrent_run(
"Create a list of known biggest risks of structural collapse" "Create a list of known biggest risks of structural collapse"
" with references") " with references"
)
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 2 assert len(result) == 2

@ -73,8 +73,12 @@ def test_run(collaboration):
def test_format_results(collaboration): def test_format_results(collaboration):
collaboration.results = [{"agent": "Agent1", "response": "Response1"}] collaboration.results = [
formatted_results = collaboration.format_results(collaboration.results) {"agent": "Agent1", "response": "Response1"}
]
formatted_results = collaboration.format_results(
collaboration.results
)
assert "Agent1 responded: Response1" in formatted_results assert "Agent1 responded: Response1" in formatted_results
@ -108,10 +112,7 @@ def test_repr(collaboration):
def test_load(collaboration): def test_load(collaboration):
state = { state = {
"step": 5, "step": 5,
"results": [{ "results": [{"agent": "Agent1", "response": "Response1"}],
"agent": "Agent1",
"response": "Response1"
}],
} }
with open(collaboration.saved_file_path_name, "w") as file: with open(collaboration.saved_file_path_name, "w") as file:
json.dump(state, file) json.dump(state, file)

@ -5,7 +5,6 @@ from swarms.structs import NonlinearWorkflow, Task
class TestNonlinearWorkflow: class TestNonlinearWorkflow:
def test_add_task(self): def test_add_task(self):
llm = OpenAIChat(openai_api_key="") llm = OpenAIChat(openai_api_key="")
task = Task(llm, "What's the weather in miami") task = Task(llm, "What's the weather in miami")
@ -34,7 +33,9 @@ class TestNonlinearWorkflow:
workflow = NonlinearWorkflow() workflow = NonlinearWorkflow()
workflow.add(task1, task2.name) workflow.add(task1, task2.name)
workflow.add(task2, task1.name) workflow.add(task2, task1.name)
with pytest.raises(Exception, match="Circular dependency detected"): with pytest.raises(
Exception, match="Circular dependency detected"
):
workflow.run() workflow.run()
def test_run_with_stopping_token(self): def test_run_with_stopping_token(self):

@ -53,7 +53,9 @@ def test_run_stop_token_not_in_result():
try: try:
workflow.run() workflow.run()
except RecursionError: except RecursionError:
pytest.fail("RecursiveWorkflow.run caused a RecursionError") pytest.fail(
"RecursiveWorkflow.run caused a RecursionError"
)
assert agent.execute.call_count == max_iterations assert agent.execute.call_count == max_iterations
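
The try/except guard above is a generic pytest idiom; a self-contained sketch of it, with a hypothetical run_workflow() standing in for RecursiveWorkflow.run:

import pytest

def run_workflow():
    # hypothetical stand-in; the real method would recurse over tasks
    return "done"

def test_run_does_not_hit_recursion_limit():
    try:
        run_workflow()
    except RecursionError:
        pytest.fail("run_workflow caused a RecursionError")
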

@ -17,7 +17,6 @@ os.environ["OPENAI_API_KEY"] = "mocked_api_key"
# Mock OpenAIChat class for testing # Mock OpenAIChat class for testing
class MockOpenAIChat: class MockOpenAIChat:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@ -27,7 +26,6 @@ class MockOpenAIChat:
# Mock Agent class for testing # Mock Agent class for testing
class MockAgent: class MockAgent:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@ -37,7 +35,6 @@ class MockAgent:
# Mock SequentialWorkflow class for testing # Mock SequentialWorkflow class for testing
class MockSequentialWorkflow: class MockSequentialWorkflow:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@ -72,7 +69,10 @@ def test_sequential_workflow_initialization():
assert len(workflow.tasks) == 0 assert len(workflow.tasks) == 0
assert workflow.max_loops == 1 assert workflow.max_loops == 1
assert workflow.autosave is False assert workflow.autosave is False
    assert (
        workflow.saved_state_filepath
        == "sequential_workflow_state.json"
    )
assert workflow.restore_state_filepath is None assert workflow.restore_state_filepath is None
assert workflow.dashboard is False assert workflow.dashboard is False
@ -177,7 +177,6 @@ def test_sequential_workflow_workflow_dashboard(capfd):
# Mock Agent class for async testing # Mock Agent class for async testing
class MockAsyncAgent: class MockAsyncAgent:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass

@ -20,7 +20,9 @@ def test_swarm_network_init(swarm_network):
@patch("swarms.structs.swarm_net.SwarmNetwork.logger") @patch("swarms.structs.swarm_net.SwarmNetwork.logger")
def test_run(mock_logger, swarm_network): def test_run(mock_logger, swarm_network):
swarm_network.run() swarm_network.run()
    assert (
        mock_logger.info.call_count == 10
    )  # 2 log messages per agent
def test_run_with_mocked_agents(mocker, swarm_network): def test_run_with_mocked_agents(mocker, swarm_network):

@ -7,7 +7,8 @@ from dotenv import load_dotenv
from swarms.models.gpt4_vision_api import GPT4VisionAPI from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
    MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.structs.task import Task from swarms.structs.task import Task
@ -20,11 +21,13 @@ def llm():
def test_agent_run_task(llm): def test_agent_run_task(llm):
task = ("Analyze this image of an assembly line and identify any" task = (
"Analyze this image of an assembly line and identify any"
" issues such as misaligned parts, defects, or deviations" " issues such as misaligned parts, defects, or deviations"
" from the standard assembly process. IF there is anything" " from the standard assembly process. IF there is anything"
" unsafe in the image, explain why it is unsafe and how it" " unsafe in the image, explain why it is unsafe and how it"
" could be improved.") " could be improved."
)
img = "assembly_line.jpg" img = "assembly_line.jpg"
agent = Agent( agent = Agent(
@ -46,7 +49,9 @@ def test_agent_run_task(llm):
@pytest.fixture @pytest.fixture
def task(): def task():
agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)] agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)]
return Task(id="Task_1", task="Task_Name", agents=agents, dependencies=[]) return Task(
id="Task_1", task="Task_Name", agents=agents, dependencies=[]
)
# Basic tests # Basic tests
@ -184,7 +189,9 @@ def test_task_execute_with_condition(mocker):
mock_agent = mocker.Mock(spec=Agent) mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result" mock_agent.run.return_value = "result"
condition = mocker.Mock(return_value=True) condition = mocker.Mock(return_value=True)
task = Task(description="Test task", agent=mock_agent, condition=condition) task = Task(
description="Test task", agent=mock_agent, condition=condition
)
task.execute() task.execute()
assert task.result == "result" assert task.result == "result"
assert task.history == ["result"] assert task.history == ["result"]
@ -194,7 +201,9 @@ def test_task_execute_with_condition_false(mocker):
mock_agent = mocker.Mock(spec=Agent) mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result" mock_agent.run.return_value = "result"
condition = mocker.Mock(return_value=False) condition = mocker.Mock(return_value=False)
task = Task(description="Test task", agent=mock_agent, condition=condition) task = Task(
description="Test task", agent=mock_agent, condition=condition
)
task.execute() task.execute()
assert task.result is None assert task.result is None
assert task.history == [] assert task.history == []
@ -204,7 +213,9 @@ def test_task_execute_with_action(mocker):
mock_agent = mocker.Mock(spec=Agent) mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result" mock_agent.run.return_value = "result"
action = mocker.Mock() action = mocker.Mock()
task = Task(description="Test task", agent=mock_agent, action=action) task = Task(
description="Test task", agent=mock_agent, action=action
)
task.execute() task.execute()
assert task.result == "result" assert task.result == "result"
assert task.history == ["result"] assert task.history == ["result"]
@ -232,9 +243,11 @@ def test_task_handle_scheduled_task_future(mocker):
agent=mock_agent, agent=mock_agent,
schedule_time=datetime.now() + timedelta(days=1), schedule_time=datetime.now() + timedelta(days=1),
) )
    with mocker.patch.object(
        task.scheduler, "enter"
    ) as mock_enter, mocker.patch.object(
        task.scheduler, "run"
    ) as mock_run:
task.handle_scheduled_task() task.handle_scheduled_task()
mock_enter.assert_called_once() mock_enter.assert_called_once()
mock_run.assert_called_once() mock_run.assert_called_once()
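
The same double-patch can be written with the standard library alone; a minimal sketch using sched.scheduler in place of the Task's scheduler (mocker.patch.object from pytest-mock behaves like unittest.mock.patch.object):

import sched
import time
from unittest.mock import patch

scheduler = sched.scheduler(time.time, time.sleep)

with patch.object(scheduler, "enter") as mock_enter, patch.object(
    scheduler, "run"
) as mock_run:
    scheduler.enter(5, 1, print, ("tick",))  # routed to the mock, nothing is scheduled
    scheduler.run()                          # likewise a no-op
    mock_enter.assert_called_once()
    mock_run.assert_called_once()
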

@ -21,9 +21,7 @@ def agent():
@pytest.fixture() @pytest.fixture()
def concrete_task_queue(): def concrete_task_queue():
class ConcreteTaskQueue(TaskQueueBase): class ConcreteTaskQueue(TaskQueueBase):
def add_task(self, task): def add_task(self, task):
pass # Here you would add concrete implementation of add_task pass # Here you would add concrete implementation of add_task
@ -53,8 +51,9 @@ def test_add_task_failure(concrete_task_queue, task):
# Assuming the task is somehow invalid # Assuming the task is somehow invalid
# Note: Concrete implementation requires logic defining what an invalid task is # Note: Concrete implementation requires logic defining what an invalid task is
concrete_task_queue.add_task(task) concrete_task_queue.add_task(task)
    assert (
        concrete_task_queue.add_task(task) is False
    )  # Adding the same task again
@pytest.mark.parametrize("invalid_task", [None, "", {}, []]) @pytest.mark.parametrize("invalid_task", [None, "", {}, []])

@ -7,7 +7,6 @@ from swarms.structs.team import Team
class TestTeam(unittest.TestCase): class TestTeam(unittest.TestCase):
def setUp(self): def setUp(self):
self.agent = Agent( self.agent = Agent(
llm=OpenAIChat(openai_api_key=""), llm=OpenAIChat(openai_api_key=""),
@ -31,17 +30,16 @@ class TestTeam(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.team.check_config( self.team.check_config(
{"config": json.dumps({ {"config": json.dumps({"agents": [], "tasks": []})}
"agents": [], )
"tasks": []
})})
def test_run(self): def test_run(self):
self.assertEqual(self.team.run(), self.task.execute()) self.assertEqual(self.team.run(), self.task.execute())
def test_sequential_loop(self): def test_sequential_loop(self):
        self.assertEqual(
            self.team._Team__sequential_loop(), self.task.execute()
        )
def test_log(self): def test_log(self):
self.assertIsNone(self.team._Team__log("Test message")) self.assertIsNone(self.team._Team__log("Test message"))

@ -27,7 +27,9 @@ def test_set_entry_point(graph_workflow):
def test_set_entry_point_nonexistent_node(graph_workflow): def test_set_entry_point_nonexistent_node(graph_workflow):
with pytest.raises(ValueError, match="Node does not exist in graph"): with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.set_entry_point("nonexistent") graph_workflow.set_entry_point("nonexistent")
@ -40,23 +42,29 @@ def test_add_edge(graph_workflow):
def test_add_edge_nonexistent_node(graph_workflow): def test_add_edge_nonexistent_node(graph_workflow):
graph_workflow.add("node1", "value1") graph_workflow.add("node1", "value1")
with pytest.raises(ValueError, match="Node does not exist in graph"): with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.add_edge("node1", "nonexistent") graph_workflow.add_edge("node1", "nonexistent")
def test_add_conditional_edges(graph_workflow): def test_add_conditional_edges(graph_workflow):
graph_workflow.add("node1", "value1") graph_workflow.add("node1", "value1")
graph_workflow.add("node2", "value2") graph_workflow.add("node2", "value2")
graph_workflow.add_conditional_edges("node1", "condition1", graph_workflow.add_conditional_edges(
{"condition_value1": "node2"}) "node1", "condition1", {"condition_value1": "node2"}
)
assert "node2" in graph_workflow.graph["node1"]["edges"] assert "node2" in graph_workflow.graph["node1"]["edges"]
def test_add_conditional_edges_nonexistent_node(graph_workflow): def test_add_conditional_edges_nonexistent_node(graph_workflow):
graph_workflow.add("node1", "value1") graph_workflow.add("node1", "value1")
with pytest.raises(ValueError, match="Node does not exist in graph"): with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.add_conditional_edges( graph_workflow.add_conditional_edges(
"node1", "condition1", {"condition_value1": "nonexistent"}) "node1", "condition1", {"condition_value1": "nonexistent"}
)
def test_run(graph_workflow): def test_run(graph_workflow):

@ -35,8 +35,9 @@ def test_log_activity_posthog(mock_posthog, mock_env):
test_function() test_function()
# Check if the Posthog capture method was called with the expected arguments # Check if the Posthog capture method was called with the expected arguments
mock_posthog.capture.assert_called_once_with("test_user_id", event_name, mock_posthog.capture.assert_called_once_with(
event_properties) "test_user_id", event_name, event_properties
)
# Test a scenario where environment variables are not set # Test a scenario where environment variables are not set

@ -46,7 +46,9 @@ def test_generate_unique_identifier():
# Generate unique identifiers and ensure they are valid UUID strings # Generate unique identifiers and ensure they are valid UUID strings
unique_id = generate_unique_identifier() unique_id = generate_unique_identifier()
assert isinstance(unique_id, str) assert isinstance(unique_id, str)
    assert uuid.UUID(
        unique_id, version=5, namespace=uuid.NAMESPACE_DNS
    )
def test_generate_user_id_edge_case(): def test_generate_user_id_edge_case():
@ -71,7 +73,9 @@ def test_get_system_info_edge_case():
# Test get_system_info for consistency # Test get_system_info for consistency
system_info1 = get_system_info() system_info1 = get_system_info()
system_info2 = get_system_info() system_info2 = get_system_info()
    assert (
        system_info1 == system_info2
    )  # Ensure system info remains the same
def test_generate_unique_identifier_edge_case(): def test_generate_unique_identifier_edge_case():

@ -20,7 +20,9 @@ headers = {
def run_pytest(): def run_pytest():
result = subprocess.run(["pytest"], capture_output=True, text=True) result = subprocess.run(
["pytest"], capture_output=True, text=True
)
return result.stdout + result.stderr return result.stdout + result.stderr
@ -54,7 +56,9 @@ def main():
errors = parse_pytest_output(pytest_output) errors = parse_pytest_output(pytest_output)
for error in errors: for error in errors:
issue_response = create_github_issue(error["title"], error["body"]) issue_response = create_github_issue(
error["title"], error["body"]
)
print(f"Issue created: {issue_response.get('html_url')}") print(f"Issue created: {issue_response.get('html_url')}")

@ -16,8 +16,9 @@ def test_default_max_tokens():
assert tokenizer.default_max_tokens() == 100000 assert tokenizer.default_max_tokens() == 100000
@pytest.mark.parametrize("model,tokens", [("claude-2.1", 200000), @pytest.mark.parametrize(
("claude", 100000)]) "model,tokens", [("claude-2.1", 200000), ("claude", 100000)]
)
def test_default_max_tokens_models(model, tokens): def test_default_max_tokens_models(model, tokens):
tokenizer = AnthropicTokenizer(model=model) tokenizer = AnthropicTokenizer(model=model)
assert tokenizer.default_max_tokens() == tokens assert tokenizer.default_max_tokens() == tokens
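
The parametrize pattern these hunks reformat, as a self-contained sketch; the LIMITS table is a hard-coded stand-in for the real AnthropicTokenizer lookup, assumed here only for illustration:

import pytest

LIMITS = {"claude-2.1": 200000, "claude": 100000}  # assumed values

@pytest.mark.parametrize(
    "model, tokens", [("claude-2.1", 200000), ("claude", 100000)]
)
def test_default_max_tokens_models(model, tokens):
    assert LIMITS[model] == tokens
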

@ -18,7 +18,9 @@ def test_post_init(base_tokenizer):
# 3. Tests for count_tokens_left with different inputs. # 3. Tests for count_tokens_left with different inputs.
def test_count_tokens_left_with_positive_diff(
    base_tokenizer, monkeypatch
):
# Mocking count_tokens to return a specific value # Mocking count_tokens to return a specific value
monkeypatch.setattr( monkeypatch.setattr(
"swarms.tokenizers.BaseTokenizer.count_tokens", "swarms.tokenizers.BaseTokenizer.count_tokens",
@ -27,7 +29,9 @@ def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch):
assert base_tokenizer.count_tokens_left("some text") == 50 assert base_tokenizer.count_tokens_left("some text") == 50
def test_count_tokens_left_with_zero_diff(
    base_tokenizer, monkeypatch
):
monkeypatch.setattr( monkeypatch.setattr(
"swarms.tokenizers.BaseTokenizer.count_tokens", "swarms.tokenizers.BaseTokenizer.count_tokens",
lambda x, y: 100, lambda x, y: 100,

@ -51,10 +51,18 @@ def test_prefix_space_tokens(hftokenizer):
# testing _maybe_add_prefix_space method # testing _maybe_add_prefix_space method
def test__maybe_add_prefix_space(hftokenizer): def test__maybe_add_prefix_space(hftokenizer):
    assert (
        hftokenizer._maybe_add_prefix_space(
            [101, 2003, 2010, 2050, 2001, 2339], " is why"
        )
        == " is why"
    )
    assert (
        hftokenizer._maybe_add_prefix_space(
            [2003, 2010, 2050, 2001, 2339], "is why"
        )
        == " is why"
    )
# continuing tests for other methods... # continuing tests for other methods...

@ -18,21 +18,31 @@ def test_default_max_tokens(openai_tokenizer):
assert openai_tokenizer.default_max_tokens() == 4096 assert openai_tokenizer.default_max_tokens() == 4096
@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3), @pytest.mark.parametrize(
(["Hello"], 4)]) "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)]
)
def test_count_tokens_single(openai_tokenizer, text, expected_output): def test_count_tokens_single(openai_tokenizer, text, expected_output):
assert (openai_tokenizer.count_tokens(text, "gpt-3") == expected_output) assert (
openai_tokenizer.count_tokens(text, "gpt-3")
== expected_output
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"texts, expected_output", "texts, expected_output",
[(["Hello, world!", "This is a test"], 6), (["Hello"], 4)], [(["Hello, world!", "This is a test"], 6), (["Hello"], 4)],
) )
def test_count_tokens_multiple(
    openai_tokenizer, texts, expected_output
):
    assert (
        openai_tokenizer.count_tokens(texts, "gpt-3")
        == expected_output
    )

@pytest.mark.parametrize(
    "text, expected_output", [("Hello, world!", 3), (["Hello"], 4)]
)
def test_len(openai_tokenizer, text, expected_output): def test_len(openai_tokenizer, text, expected_output):
assert openai_tokenizer.len(text, "gpt-3") == expected_output assert openai_tokenizer.len(text, "gpt-3") == expected_output

@ -7,7 +7,9 @@ from swarms.tokenizers.r_tokenizers import Tokenizer
def test_initializer_existing_model_file(): def test_initializer_existing_model_file():
with patch("os.path.exists", return_value=True): with patch("os.path.exists", return_value=True):
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
mock_model.assert_called_with("tokenizers/my_model.model") mock_model.assert_called_with("tokenizers/my_model.model")
assert tokenizer.model == mock_model.return_value assert tokenizer.model == mock_model.return_value
@ -15,43 +17,66 @@ def test_initializer_existing_model_file():
def test_initializer_model_folder(): def test_initializer_model_folder():
with patch("os.path.exists", side_effect=[False, True]): with patch("os.path.exists", side_effect=[False, True]):
with patch("swarms.tokenizers.HuggingFaceTokenizer") as mock_model: with patch(
"swarms.tokenizers.HuggingFaceTokenizer"
) as mock_model:
tokenizer = Tokenizer("my_model_directory") tokenizer = Tokenizer("my_model_directory")
mock_model.assert_called_with("my_model_directory") mock_model.assert_called_with("my_model_directory")
assert tokenizer.model == mock_model.return_value assert tokenizer.model == mock_model.return_value
def test_vocab_size(): def test_vocab_size():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.vocab_size == mock_model.return_value.vocab_size) assert (
tokenizer.vocab_size == mock_model.return_value.vocab_size
)
def test_bos_token_id(): def test_bos_token_id():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.bos_token_id == mock_model.return_value.bos_token_id) assert (
tokenizer.bos_token_id
== mock_model.return_value.bos_token_id
)
def test_encode(): def test_encode():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.encode("hello") == assert (
mock_model.return_value.encode.return_value) tokenizer.encode("hello")
== mock_model.return_value.encode.return_value
)
def test_decode(): def test_decode():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.decode( assert (
[1, 2, 3]) == mock_model.return_value.decode.return_value) tokenizer.decode([1, 2, 3])
== mock_model.return_value.decode.return_value
)
def test_call(): def test_call():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model: with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model") tokenizer = Tokenizer("tokenizers/my_model.model")
assert ( assert (
tokenizer("hello") == mock_model.return_value.__call__.return_value) tokenizer("hello")
== mock_model.return_value.__call__.return_value
)
# More tests can be added here # More tests can be added here
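
The mock_model.return_value assertions above follow the standard unittest.mock contract; a stdlib-only sketch of the same identity check, with json.loads standing in for SentencePieceTokenizer purely for illustration:

import json
from unittest.mock import patch

with patch("json.loads") as mock_loads:
    result = json.loads('{"a": 1}')             # the call is routed to the mock
    mock_loads.assert_called_once_with('{"a": 1}')
    assert result is mock_loads.return_value    # what the tests above assert with ==
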

@ -65,20 +65,23 @@ def test_structured_tool_invoke():
def test_tool_creation(): def test_tool_creation():
tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") tool = Tool(
name="test_tool", func=lambda x: x, description="Test tool"
)
assert tool.name == "test_tool" assert tool.name == "test_tool"
assert tool.func is not None assert tool.func is not None
assert tool.description == "Test tool" assert tool.description == "Test tool"
def test_tool_ainvoke(): def test_tool_ainvoke():
tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") tool = Tool(
name="test_tool", func=lambda x: x, description="Test tool"
)
result = tool.ainvoke("input_data") result = tool.ainvoke("input_data")
assert result == "input_data" assert result == "input_data"
def test_tool_ainvoke_with_coroutine(): def test_tool_ainvoke_with_coroutine():
async def async_function(input_data): async def async_function(input_data):
return input_data return input_data
@ -92,7 +95,6 @@ def test_tool_ainvoke_with_coroutine():
def test_tool_args(): def test_tool_args():
def sample_function(input_data): def sample_function(input_data):
return input_data return input_data
@ -108,7 +110,6 @@ def test_tool_args():
def test_structured_tool_creation(): def test_structured_tool_creation():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -125,7 +126,6 @@ def test_structured_tool_creation():
def test_structured_tool_ainvoke(): def test_structured_tool_ainvoke():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -140,7 +140,6 @@ def test_structured_tool_ainvoke():
def test_structured_tool_ainvoke_with_coroutine(): def test_structured_tool_ainvoke_with_coroutine():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -158,7 +157,6 @@ def test_structured_tool_ainvoke_with_coroutine():
def test_structured_tool_args(): def test_structured_tool_args():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -184,13 +182,14 @@ def test_tool_ainvoke_exception():
def test_tool_ainvoke_with_coroutine_exception(): def test_tool_ainvoke_with_coroutine_exception():
tool = Tool(name="test_tool", coroutine=None, description="Test tool") tool = Tool(
name="test_tool", coroutine=None, description="Test tool"
)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
tool.ainvoke("input_data") tool.ainvoke("input_data")
def test_structured_tool_ainvoke_exception(): def test_structured_tool_ainvoke_exception():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -205,7 +204,6 @@ def test_structured_tool_ainvoke_exception():
def test_structured_tool_ainvoke_with_coroutine_exception(): def test_structured_tool_ainvoke_with_coroutine_exception():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -227,7 +225,6 @@ def test_tool_description_not_provided():
def test_tool_invoke_with_callbacks(): def test_tool_invoke_with_callbacks():
def sample_function(input_data, callbacks=None): def sample_function(input_data, callbacks=None):
if callbacks: if callbacks:
callbacks.on_start() callbacks.on_start()
@ -243,7 +240,6 @@ def test_tool_invoke_with_callbacks():
def test_tool_invoke_with_new_argument(): def test_tool_invoke_with_new_argument():
def sample_function(input_data, callbacks=None): def sample_function(input_data, callbacks=None):
return input_data return input_data
@ -253,7 +249,6 @@ def test_tool_invoke_with_new_argument():
def test_tool_ainvoke_with_new_argument(): def test_tool_ainvoke_with_new_argument():
async def async_function(input_data, callbacks=None): async def async_function(input_data, callbacks=None):
return input_data return input_data
@ -263,7 +258,6 @@ def test_tool_ainvoke_with_new_argument():
def test_tool_description_from_docstring(): def test_tool_description_from_docstring():
def sample_function(input_data): def sample_function(input_data):
"""Sample function docstring""" """Sample function docstring"""
return input_data return input_data
@ -273,7 +267,6 @@ def test_tool_description_from_docstring():
def test_tool_ainvoke_with_exceptions(): def test_tool_ainvoke_with_exceptions():
async def async_function(input_data): async def async_function(input_data):
raise ValueError("Test exception") raise ValueError("Test exception")
@ -286,7 +279,6 @@ def test_tool_ainvoke_with_exceptions():
def test_structured_tool_infer_schema_false(): def test_structured_tool_infer_schema_false():
def sample_function(input_data): def sample_function(input_data):
return input_data return input_data
@ -300,7 +292,6 @@ def test_structured_tool_infer_schema_false():
def test_structured_tool_ainvoke_with_callbacks(): def test_structured_tool_ainvoke_with_callbacks():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -316,14 +307,15 @@ def test_structured_tool_ainvoke_with_callbacks():
args_schema=SampleArgsSchema, args_schema=SampleArgsSchema,
) )
callbacks = MagicMock() callbacks = MagicMock()
result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks) result = tool.ainvoke(
{"tool_input": "input_data"}, callbacks=callbacks
)
assert result == "input_data" assert result == "input_data"
callbacks.on_start.assert_called_once() callbacks.on_start.assert_called_once()
callbacks.on_finish.assert_called_once() callbacks.on_finish.assert_called_once()
def test_structured_tool_description_not_provided(): def test_structured_tool_description_not_provided():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -338,7 +330,6 @@ def test_structured_tool_description_not_provided():
def test_structured_tool_args_schema(): def test_structured_tool_args_schema():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -354,7 +345,6 @@ def test_structured_tool_args_schema():
def test_structured_tool_args_schema_inference(): def test_structured_tool_args_schema_inference():
def sample_function(input_data): def sample_function(input_data):
return input_data return input_data
@ -368,7 +358,6 @@ def test_structured_tool_args_schema_inference():
def test_structured_tool_ainvoke_with_new_argument(): def test_structured_tool_ainvoke_with_new_argument():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -380,12 +369,13 @@ def test_structured_tool_ainvoke_with_new_argument():
func=sample_function, func=sample_function,
args_schema=SampleArgsSchema, args_schema=SampleArgsSchema,
) )
result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None) result = tool.ainvoke(
{"tool_input": "input_data"}, callbacks=None
)
assert result == "input_data" assert result == "input_data"
def test_structured_tool_ainvoke_with_exceptions(): def test_structured_tool_ainvoke_with_exceptions():
class SampleArgsSchema: class SampleArgsSchema:
pass pass
@ -471,7 +461,9 @@ def test_tool_with_runnable(mock_runnable):
def test_tool_with_invalid_argument(): def test_tool_with_invalid_argument():
# Test passing an invalid argument type # Test passing an invalid argument type
with pytest.raises(ValueError): with pytest.raises(ValueError):
        tool(
            123
        )  # Using an integer instead of a string/callable/Runnable
def test_tool_with_multiple_arguments(mock_func): def test_tool_with_multiple_arguments(mock_func):
@ -533,7 +525,9 @@ class MockSchema(BaseModel):
# Test suite starts here # Test suite starts here
class TestTool: class TestTool:
# Basic Functionality Tests # Basic Functionality Tests
    def test_tool_with_valid_callable_creates_base_tool(
        self, mock_func
    ):
result = tool(mock_func) result = tool(mock_func)
assert isinstance(result, BaseTool) assert isinstance(result, BaseTool)
@ -566,7 +560,6 @@ class TestTool:
# Error Handling Tests # Error Handling Tests
def test_tool_raises_error_without_docstring(self): def test_tool_raises_error_without_docstring(self):
def no_doc_func(arg: str) -> str: def no_doc_func(arg: str) -> str:
return arg return arg
@ -574,14 +567,14 @@ class TestTool:
tool(no_doc_func) tool(no_doc_func)
def test_tool_raises_error_runnable_without_object_schema( def test_tool_raises_error_runnable_without_object_schema(
self, mock_runnable): self, mock_runnable
):
with pytest.raises(ValueError): with pytest.raises(ValueError):
tool(mock_runnable) tool(mock_runnable)
# Decorator Behavior Tests # Decorator Behavior Tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_tool_function(self): async def test_async_tool_function(self):
@tool @tool
async def async_func(arg: str) -> str: async def async_func(arg: str) -> str:
return arg return arg
@ -604,7 +597,6 @@ class TestTool:
pass pass
def test_tool_with_different_return_types(self): def test_tool_with_different_return_types(self):
@tool @tool
def return_int(arg: str) -> int: def return_int(arg: str) -> int:
return int(arg) return int(arg)
@ -623,7 +615,6 @@ class TestTool:
# Test with multiple arguments # Test with multiple arguments
def test_tool_with_multiple_args(self): def test_tool_with_multiple_args(self):
@tool @tool
def concat_strings(a: str, b: str) -> str: def concat_strings(a: str, b: str) -> str:
return a + b return a + b
@ -633,7 +624,6 @@ class TestTool:
# Test handling of optional arguments # Test handling of optional arguments
def test_tool_with_optional_args(self): def test_tool_with_optional_args(self):
@tool @tool
def greet(name: str, greeting: str = "Hello") -> str: def greet(name: str, greeting: str = "Hello") -> str:
return f"{greeting} {name}" return f"{greeting} {name}"
@ -643,7 +633,6 @@ class TestTool:
# Test with variadic arguments # Test with variadic arguments
def test_tool_with_variadic_args(self): def test_tool_with_variadic_args(self):
@tool @tool
def sum_numbers(*numbers: int) -> int: def sum_numbers(*numbers: int) -> int:
return sum(numbers) return sum(numbers)
@ -653,7 +642,6 @@ class TestTool:
# Test with keyword arguments # Test with keyword arguments
def test_tool_with_kwargs(self): def test_tool_with_kwargs(self):
@tool @tool
def build_query(**kwargs) -> str: def build_query(**kwargs) -> str:
return "&".join(f"{k}={v}" for k, v in kwargs.items()) return "&".join(f"{k}={v}" for k, v in kwargs.items())
@ -663,7 +651,6 @@ class TestTool:
# Test with mixed types of arguments # Test with mixed types of arguments
def test_tool_with_mixed_args(self): def test_tool_with_mixed_args(self):
@tool @tool
def mixed_args(a: int, b: str, *args, **kwargs) -> str: def mixed_args(a: int, b: str, *args, **kwargs) -> str:
return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}" return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}"
@ -672,7 +659,6 @@ class TestTool:
# Test error handling with incorrect types # Test error handling with incorrect types
def test_tool_error_with_incorrect_types(self): def test_tool_error_with_incorrect_types(self):
@tool @tool
def add_numbers(a: int, b: int) -> int: def add_numbers(a: int, b: int) -> int:
return a + b return a + b
@ -682,7 +668,6 @@ class TestTool:
# Test with nested tools # Test with nested tools
def test_nested_tools(self): def test_nested_tools(self):
@tool @tool
def inner_tool(arg: str) -> str: def inner_tool(arg: str) -> str:
return f"Inner {arg}" return f"Inner {arg}"
@ -694,7 +679,6 @@ class TestTool:
assert outer_tool("Test") == "Outer Inner Test" assert outer_tool("Test") == "Outer Inner Test"
def test_tool_with_global_variable(self): def test_tool_with_global_variable(self):
@tool @tool
def access_global(arg: str) -> str: def access_global(arg: str) -> str:
return f"{global_var} {arg}" return f"{global_var} {arg}"
@ -717,7 +701,6 @@ class TestTool:
# Test with complex data structures # Test with complex data structures
def test_tool_with_complex_data_structures(self): def test_tool_with_complex_data_structures(self):
@tool @tool
def process_data(data: dict) -> list: def process_data(data: dict) -> list:
return [data[key] for key in sorted(data.keys())] return [data[key] for key in sorted(data.keys())]
@ -727,7 +710,6 @@ class TestTool:
# Test handling exceptions within the tool function # Test handling exceptions within the tool function
def test_tool_handling_internal_exceptions(self): def test_tool_handling_internal_exceptions(self):
@tool @tool
def function_that_raises(arg: str): def function_that_raises(arg: str):
if arg == "error": if arg == "error":
@ -740,7 +722,6 @@ class TestTool:
# Test with functions returning None # Test with functions returning None
def test_tool_with_none_return(self): def test_tool_with_none_return(self):
@tool @tool
def return_none(arg: str): def return_none(arg: str):
return None return None
@ -754,9 +735,7 @@ class TestTool:
# Test with class methods # Test with class methods
def test_tool_with_class_method(self): def test_tool_with_class_method(self):
class MyClass: class MyClass:
@tool @tool
def method(self, arg: str) -> str: def method(self, arg: str) -> str:
return f"Method {arg}" return f"Method {arg}"
@ -766,15 +745,12 @@ class TestTool:
# Test tool function with inheritance # Test tool function with inheritance
def test_tool_with_inheritance(self): def test_tool_with_inheritance(self):
class Parent: class Parent:
@tool @tool
def parent_method(self, arg: str) -> str: def parent_method(self, arg: str) -> str:
return f"Parent {arg}" return f"Parent {arg}"
class Child(Parent): class Child(Parent):
@tool @tool
def child_method(self, arg: str) -> str: def child_method(self, arg: str) -> str:
return f"Child {arg}" return f"Child {arg}"
@ -785,9 +761,7 @@ class TestTool:
# Test with decorators stacking # Test with decorators stacking
def test_tool_with_multiple_decorators(self): def test_tool_with_multiple_decorators(self):
def another_decorator(func): def another_decorator(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
return f"Decorated {func(*args, **kwargs)}" return f"Decorated {func(*args, **kwargs)}"
@ -813,7 +787,9 @@ class TestTool:
def thread_target(): def thread_target():
results.append(threaded_function(5)) results.append(threaded_function(5))
threads = [threading.Thread(target=thread_target) for _ in range(10)] threads = [
threading.Thread(target=thread_target) for _ in range(10)
]
for t in threads: for t in threads:
t.start() t.start()
for t in threads: for t in threads:
@ -823,7 +799,6 @@ class TestTool:
# Test with recursive functions # Test with recursive functions
def test_tool_with_recursive_function(self): def test_tool_with_recursive_function(self):
@tool @tool
def recursive_function(n: int) -> int: def recursive_function(n: int) -> int:
if n == 0: if n == 0:
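Every hunk in this commit follows the same mechanical rewrite: a call or signature that overruns the formatter's line length is split so the arguments move onto their own indented line and the closing parenthesis lands on a line of its own. A minimal before/after sketch of that convention, assuming a Black-style formatter with a line length of roughly 70 characters (the exact pre-commit configuration is not shown in this diff):

# Before: a single call that exceeds the configured line length.
result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks)

# After: arguments indented on their own line, closing paren on its own line.
result = tool.ainvoke(
    {"tool_input": "input_data"}, callbacks=callbacks
)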

@@ -19,8 +19,9 @@ def test_check_device_no_cuda(monkeypatch):
 def test_check_device_cuda_exception(monkeypatch):
     # Mock torch.cuda.is_available to raise an exception
-    monkeypatch.setattr(torch.cuda, "is_available",
-                        lambda: 1 / 0)  # Raises ZeroDivisionError
+    monkeypatch.setattr(
+        torch.cuda, "is_available", lambda: 1 / 0
+    )  # Raises ZeroDivisionError
     result = check_device(log_level=logging.DEBUG)
     assert result.type == "cpu"
@@ -32,8 +33,12 @@ def test_check_device_one_cuda(monkeypatch):
     # Mock torch.cuda.device_count to return 1
     monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
-    monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
-    monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
+    monkeypatch.setattr(
+        torch.cuda, "memory_allocated", lambda device: 0
+    )
+    monkeypatch.setattr(
+        torch.cuda, "memory_reserved", lambda device: 0
+    )
     result = check_device(log_level=logging.DEBUG)
     assert len(result) == 1
@@ -47,8 +52,12 @@ def test_check_device_multiple_cuda(monkeypatch):
     # Mock torch.cuda.device_count to return 4
     monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
     # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
-    monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
-    monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
+    monkeypatch.setattr(
+        torch.cuda, "memory_allocated", lambda device: 0
+    )
+    monkeypatch.setattr(
+        torch.cuda, "memory_reserved", lambda device: 0
+    )
     result = check_device(log_level=logging.DEBUG)
     assert len(result) == 4

@@ -19,7 +19,8 @@ def test_print_class_parameters_agent():
     # Replace with the expected output for Agent class
     expected_output = (
         "Parameter: name, Type: <class 'str'>\nParameter: age, Type:"
-        " <class 'int'>")
+        " <class 'int'>"
+    )
     assert output == expected_output
@@ -32,7 +33,9 @@ def test_print_class_parameters_error():
 def get_parameters(class_name: str):
     classes = {"Agent": Agent}
     if class_name in classes:
-        return print_class_parameters(classes[class_name], api_format=True)
+        return print_class_parameters(
+            classes[class_name], api_format=True
+        )
     else:
         return {"error": "Class not found"}

@@ -32,14 +32,18 @@ def test_multiple_gpus_available(mocker):
 def test_device_properties(mocker):
     mocker.patch("torch.cuda.is_available", return_value=True)
     mocker.patch("torch.cuda.device_count", return_value=1)
-    mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
+    mocker.patch(
+        "torch.cuda.get_device_capability", return_value=(7, 5)
+    )
     mocker.patch(
         "torch.cuda.get_device_properties",
         return_value=MagicMock(total_memory=1000),
     )
     mocker.patch("torch.cuda.memory_allocated", return_value=200)
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
+    mocker.patch(
+        "torch.cuda.get_device_name", return_value="Tesla K80"
+    )
     devices = check_device()
     assert len(devices) == 1
     assert str(devices[0]) == "cuda"
@@ -48,21 +52,27 @@ def test_device_properties(mocker):
 def test_memory_threshold(mocker):
     mocker.patch("torch.cuda.is_available", return_value=True)
     mocker.patch("torch.cuda.device_count", return_value=1)
-    mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
+    mocker.patch(
+        "torch.cuda.get_device_capability", return_value=(7, 5)
+    )
     mocker.patch(
         "torch.cuda.get_device_properties",
         return_value=MagicMock(total_memory=1000),
     )
-    mocker.patch("torch.cuda.memory_allocated",
-                 return_value=900)  # 90% of total memory
+    mocker.patch(
+        "torch.cuda.memory_allocated", return_value=900
+    )  # 90% of total memory
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
+    mocker.patch(
+        "torch.cuda.get_device_name", return_value="Tesla K80"
+    )
     with pytest.warns(
         UserWarning,
         match=r"Memory usage for device cuda exceeds threshold",
     ):
         devices = check_device(
-            memory_threshold=0.8)  # Set memory threshold to 80%
+            memory_threshold=0.8
+        )  # Set memory threshold to 80%
     assert len(devices) == 1
     assert str(devices[0]) == "cuda"
@@ -70,21 +80,27 @@ def test_memory_threshold(mocker):
 def test_compute_capability_threshold(mocker):
     mocker.patch("torch.cuda.is_available", return_value=True)
     mocker.patch("torch.cuda.device_count", return_value=1)
-    mocker.patch("torch.cuda.get_device_capability",
-                 return_value=(3, 0))  # Compute capability 3.0
+    mocker.patch(
+        "torch.cuda.get_device_capability", return_value=(3, 0)
+    )  # Compute capability 3.0
     mocker.patch(
         "torch.cuda.get_device_properties",
         return_value=MagicMock(total_memory=1000),
     )
     mocker.patch("torch.cuda.memory_allocated", return_value=200)
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
+    mocker.patch(
+        "torch.cuda.get_device_name", return_value="Tesla K80"
+    )
     with pytest.warns(
         UserWarning,
-        match=(r"Compute capability for device cuda is below threshold"),
+        match=(
+            r"Compute capability for device cuda is below threshold"
+        ),
     ):
         devices = check_device(
-            capability_threshold=3.5)  # Set compute capability threshold to 3.5
+            capability_threshold=3.5
+        )  # Set compute capability threshold to 3.5
     assert len(devices) == 1
     assert str(devices[0]) == "cuda"

@@ -14,7 +14,8 @@ def test_basic_message():
     with mock.patch.object(Console, "print") as mock_print:
         display_markdown_message("This is a test")
         mock_print.assert_called_once_with(
-            Markdown("This is a test", style="cyan"))
+            Markdown("This is a test", style="cyan")
+        )
 def test_empty_message():
@@ -30,7 +31,8 @@ def test_colors(color):
     with mock.patch.object(Console, "print") as mock_print:
         display_markdown_message("This is a test", color)
         mock_print.assert_called_once_with(
-            Markdown("This is a test", style=color))
+            Markdown("This is a test", style=color)
+        )
 def test_dash_line():

@@ -22,8 +22,12 @@ def markdown_content_without_code():
     """
-def test_extract_code_from_markdown_with_code(markdown_content_with_code,):
-    extracted_code = extract_code_from_markdown(markdown_content_with_code)
+def test_extract_code_from_markdown_with_code(
+    markdown_content_with_code,
+):
+    extracted_code = extract_code_from_markdown(
+        markdown_content_with_code
+    )
     assert "def my_func():" in extracted_code
     assert 'print("This is my function.")' in extracted_code
     assert "class MyClass:" in extracted_code
@@ -31,8 +35,11 @@ def test_extract_code_from_markdown_with_code(markdown_content_with_code,):
 def test_extract_code_from_markdown_without_code(
-    markdown_content_without_code,):
-    extracted_code = extract_code_from_markdown(markdown_content_without_code)
+    markdown_content_without_code,
+):
+    extracted_code = extract_code_from_markdown(
+        markdown_content_without_code
+    )
     assert extracted_code == ""

@@ -8,8 +8,12 @@ from swarms.utils import find_image_path
 def test_find_image_path_no_images():
-    assert (find_image_path("This is a test string without any image paths.")
-            is None)
+    assert (
+        find_image_path(
+            "This is a test string without any image paths."
+        )
+        is None
+    )
 def test_find_image_path_one_image():
@@ -19,7 +23,9 @@ def test_find_image_path_one_image():
 def test_find_image_path_multiple_images():
     text = "This string has two image paths: img1.png, and img2.jpg."
-    assert (find_image_path(text) == "img2.jpg")  # Assuming both images exist
+    assert (
+        find_image_path(text) == "img2.jpg"
+    )  # Assuming both images exist
 def test_find_image_path_wrong_input():
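A small usage sketch of find_image_path, inferred only from the assertions above: it returns None when the text contains no image paths, and the last path that resolves to an existing file when several are present (hence the "Assuming both images exist" note in the test):

from swarms.utils import find_image_path

text = "This string has two image paths: img1.png, and img2.jpg."
path = find_image_path(text)
if path is not None:
    # Expected to be "img2.jpg" when both files exist on disk.
    print(f"Found image: {path}")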

@@ -4,11 +4,14 @@ from swarms.utils import limit_tokens_from_string
 def test_limit_tokens_from_string():
-    sentence = ("This is a test sentence. It is used for testing the number"
-                " of tokens.")
+    sentence = (
+        "This is a test sentence. It is used for testing the number"
+        " of tokens."
+    )
     limited = limit_tokens_from_string(sentence, limit=5)
-    assert (len(limited.split())
-            <= 5), "The output string has more than 5 tokens."
+    assert (
+        len(limited.split()) <= 5
+    ), "The output string has more than 5 tokens."
 def test_limit_zero_tokens():
@@ -18,21 +21,26 @@ def test_limit_zero_tokens():
 def test_negative_token_limit():
-    sentence = ("This test will raise an exception when limit is negative.")
+    sentence = (
+        "This test will raise an exception when limit is negative."
+    )
     with pytest.raises(Exception):
         limit_tokens_from_string(sentence, limit=-1)
-@pytest.mark.parametrize("sentence, model",
-                         [("Some sentence", "unavailable-model")])
+@pytest.mark.parametrize(
+    "sentence, model", [("Some sentence", "unavailable-model")]
+)
 def test_unknown_model(sentence, model):
     with pytest.raises(Exception):
         limit_tokens_from_string(sentence, model=model)
 def test_string_token_limit_exceeded():
-    sentence = ("This is a long sentence with more than twenty tokens which"
-                " is used for testing. It checks whether the function"
-                " correctly limits the tokens to a specified amount.")
+    sentence = (
+        "This is a long sentence with more than twenty tokens which"
+        " is used for testing. It checks whether the function"
+        " correctly limits the tokens to a specified amount."
+    )
     limited = limit_tokens_from_string(sentence, limit=20)
     assert len(limited.split()) <= 20, "The token limit is exceeded."
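A brief usage sketch of limit_tokens_from_string, based only on the behaviour these tests pin down: the returned string is truncated to at most the requested number of tokens, a negative limit raises, and an unknown model name raises:

from swarms.utils import limit_tokens_from_string

text = (
    "This is a test sentence. It is used for testing the number"
    " of tokens."
)
short = limit_tokens_from_string(text, limit=5)
assert len(short.split()) <= 5

# An unrecognised model name raises, per test_unknown_model above.
# limit_tokens_from_string(text, model="unavailable-model")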

@@ -6,7 +6,6 @@ from swarms.utils import load_model_torch
 class DummyModel(nn.Module):
     def __init__(self):
         super().__init__()
         self.fc = nn.Linear(10, 2)
@@ -26,7 +25,9 @@ def test_load_model_torch_success(tmp_path):
     model_loaded = load_model_torch(model_path, model=DummyModel())
     # Check if loaded model has the same architecture
-    assert isinstance(model_loaded, DummyModel), "Loaded model type mismatch."
+    assert isinstance(
+        model_loaded, DummyModel
+    ), "Loaded model type mismatch."
 # Test case 2: Test if function raises FileNotFoundError for non-existent file
@@ -65,11 +66,12 @@ def test_load_model_torch_device_handling(tmp_path):
     # Define a device other than default and load the model to the specified device
     device = torch.device("cpu")
-    model_loaded = load_model_torch(model_path,
-                                    model=DummyModel(),
-                                    device=device)
+    model_loaded = load_model_torch(
+        model_path, model=DummyModel(), device=device
+    )
-    assert (model_loaded.fc.weight.device == device
+    assert (
+        model_loaded.fc.weight.device == device
     ), "Model not loaded to specified device."
@@ -80,14 +82,15 @@ def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path):
     torch.save(model.state_dict(), model_path)
     def mock_torch_load(*args, **kwargs):
-        assert ("pickle_module"
-                in kwargs), "Keyword arguments not passed to 'torch.load'."
+        assert (
+            "pickle_module" in kwargs
+        ), "Keyword arguments not passed to 'torch.load'."
     # Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly
     monkeypatch.setattr(torch, "load", mock_torch_load)
-    load_model_torch(model_path,
-                     model=DummyModel(),
-                     pickle_module="dummy_module")
+    load_model_torch(
+        model_path, model=DummyModel(), pickle_module="dummy_module"
+    )
 # Test case 7: Test for model loading on CPU if no GPU is available
@@ -100,8 +103,9 @@ def test_load_model_torch_cpu(tmp_path):
         return False
     # Monkeypatch to simulate no GPU available
-    pytest.MonkeyPatch.setattr(torch.cuda, "is_available",
-                               mock_torch_cuda_is_available)
+    pytest.MonkeyPatch.setattr(
+        torch.cuda, "is_available", mock_torch_cuda_is_available
+    )
     model_loaded = load_model_torch(model_path, model=DummyModel())
     # Ensure model is loaded on CPU

@@ -42,13 +42,15 @@ def test_load_model_torch_model_specified(mocker):
     mock_model = MagicMock(spec=torch.nn.Module)
     mocker.patch("torch.load", return_value={"key": "value"})
     load_model_torch("model_path", model=mock_model)
-    mock_model.load_state_dict.assert_called_once_with({"key": "value"},
-                                                       strict=True)
+    mock_model.load_state_dict.assert_called_once_with(
+        {"key": "value"}, strict=True
+    )
 def test_load_model_torch_model_specified_strict_false(mocker):
     mock_model = MagicMock(spec=torch.nn.Module)
     mocker.patch("torch.load", return_value={"key": "value"})
     load_model_torch("model_path", model=mock_model, strict=False)
-    mock_model.load_state_dict.assert_called_once_with({"key": "value"},
-                                                       strict=False)
+    mock_model.load_state_dict.assert_called_once_with(
+        {"key": "value"}, strict=False
+    )

@@ -18,7 +18,6 @@ def func2_with_exception(x):
 def test_same_results_no_exception(caplog):
     @math_eval(func1_no_exception, func2_no_exception)
     def test_func(x):
         return x
@@ -29,7 +28,6 @@ def test_same_results_no_exception(caplog):
 def test_func1_exception(caplog):
     @math_eval(func1_with_exception, func2_no_exception)
     def test_func(x):
         return x

@@ -10,7 +10,6 @@ from swarms.utils import metrics_decorator
 # Basic successful test
 def test_metrics_decorator_success():
     @metrics_decorator
     def decorated_func():
         time.sleep(0.1)
@@ -31,8 +30,8 @@ def test_metrics_decorator_success():
     ],
 )
 def test_metrics_decorator_with_various_wait_times_and_return_vals(
-        wait_time, return_val):
+    wait_time, return_val
+):
     @metrics_decorator
     def decorated_func():
         time.sleep(wait_time)
@@ -56,17 +55,19 @@ def test_metrics_decorator_with_mocked_time(mocker):
         return ["tok_1", "tok_2"]
     metrics = decorated_func()
-    assert (metrics == """
+    assert (
+        metrics
+        == """
 Time to First Token: 5
 Generation Latency: 20
 Throughput: 0.1
-""")
+"""
+    )
     mocked_time.assert_any_call()
 # Test to ensure that exceptions in the decorated function are propagated
 def test_metrics_decorator_raises_exception():
     @metrics_decorator
     def decorated_func():
         raise ValueError("Oops!")
@@ -77,7 +78,6 @@ def test_metrics_decorator_raises_exception():
 # Test to ensure proper handling when decorated function returns non-list value
 def test_metrics_decorator_with_non_list_return_val():
     @metrics_decorator
     def decorated_func():
         return "Hello, world!"

@@ -9,13 +9,16 @@ from swarms.utils import prep_torch_inference
 def test_prep_torch_inference():
     model_path = "model_path"
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )
     model_mock = Mock()
     model_mock.eval = Mock()
     # Mocking the load_model_torch function to return our mock model.
-    with unittest.mock.patch("swarms.utils.load_model_torch",
-                             return_value=model_mock) as _:
+    with unittest.mock.patch(
+        "swarms.utils.load_model_torch", return_value=model_mock
+    ) as _:
         model = prep_torch_inference(model_path, device)
         # Check if model was properly loaded and eval function was called

@@ -3,7 +3,8 @@ from unittest.mock import MagicMock
 import torch
 from swarms.utils.prep_torch_model_inference import (
-    prep_torch_inference,)
+    prep_torch_inference,
+)
 def test_prep_torch_inference_no_model_path():

@@ -4,30 +4,27 @@ from swarms.utils import print_class_parameters
 class TestObject:
     def __init__(self, value1, value2: int):
         pass
 class TestObject2:
     def __init__(self: "TestObject2", value1, value2: int = 5):
         pass
 def test_class_with_complex_parameters():
     class ComplexArgs:
         def __init__(self, value1: list, value2: dict = {}):
             pass
     output = {"value1": "<class 'list'>", "value2": "<class 'dict'>"}
-    assert (print_class_parameters(ComplexArgs, api_format=True) == output)
+    assert (
+        print_class_parameters(ComplexArgs, api_format=True) == output
+    )
 def test_empty_class():
     class Empty:
         pass
@@ -36,9 +33,7 @@ def test_empty_class():
 def test_class_with_no_annotations():
     class NoAnnotations:
         def __init__(self, value1, value2):
             pass
@@ -46,13 +41,14 @@ def test_class_with_no_annotations():
         "value1": "<class 'inspect._empty'>",
         "value2": "<class 'inspect._empty'>",
     }
-    assert (print_class_parameters(NoAnnotations, api_format=True) == output)
+    assert (
+        print_class_parameters(NoAnnotations, api_format=True)
+        == output
+    )
 def test_class_with_partial_annotations():
     class PartialAnnotations:
         def __init__(self, value1, value2: int):
             pass
@@ -60,8 +56,10 @@ def test_class_with_partial_annotations():
         "value1": "<class 'inspect._empty'>",
         "value2": "<class 'int'>",
     }
-    assert (print_class_parameters(PartialAnnotations,
-                                   api_format=True) == output)
+    assert (
+        print_class_parameters(PartialAnnotations, api_format=True)
+        == output
+    )
 @pytest.mark.parametrize(

Some files were not shown because too many files have changed in this diff.