run pre-commit code_quality check

pull/408/head
Joshua David 1 year ago
parent 36c295ed8b
commit 5b121a7445

2
.gitignore vendored

@ -110,7 +110,7 @@ docs/_build/
# PyBuilder
.pybuilder/
target/
`
# Jupyter Notebook
.ipynb_checkpoints

@ -33,7 +33,7 @@ CODE
"""
# Initialize the language model
llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000)
llm = OpenAIChat(openai_api_key=api_key, max_tokens=4096)
# Documentation agent

@ -3,7 +3,7 @@ from pathlib import Path
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma

@ -23,15 +23,19 @@ def test_multion_agent_run(mock_multion):
assert result == "result"
assert status == "status"
assert last_url == "lastUrl"
mock_multion.browse.assert_called_once_with({
"cmd": "task",
"url": "https://www.example.com",
"maxSteps": 5,
})
mock_multion.browse.assert_called_once_with(
{
"cmd": "task",
"url": "https://www.example.com",
"maxSteps": 5,
}
)
# Additional tests for different tasks
@pytest.mark.parametrize("task", ["task1", "task2", "task3", "task4", "task5"])
@pytest.mark.parametrize(
"task", ["task1", "task2", "task3", "task4", "task5"]
)
@patch("swarms.agents.multion_agent.multion")
def test_multion_agent_run_different_tasks(mock_multion, task):
mock_response = MagicMock()
@ -50,8 +54,6 @@ def test_multion_agent_run_different_tasks(mock_multion, task):
assert result == "result"
assert status == "status"
assert last_url == "lastUrl"
mock_multion.browse.assert_called_once_with({
"cmd": task,
"url": "https://www.example.com",
"maxSteps": 5
})
mock_multion.browse.assert_called_once_with(
{"cmd": task, "url": "https://www.example.com", "maxSteps": 5}
)

@ -11,27 +11,18 @@ def test_tool_agent_init():
json_schema = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "number"
},
"is_student": {
"type": "boolean"
},
"courses": {
"type": "array",
"items": {
"type": "string"
}
},
"name": {"type": "string"},
"age": {"type": "number"},
"is_student": {"type": "boolean"},
"courses": {"type": "array", "items": {"type": "string"}},
},
}
name = "Test Agent"
description = "This is a test agent"
agent = ToolAgent(name, description, model, tokenizer, json_schema)
agent = ToolAgent(
name, description, model, tokenizer, json_schema
)
assert agent.name == name
assert agent.description == description
@ -47,29 +38,22 @@ def test_tool_agent_run(mock_run):
json_schema = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "number"
},
"is_student": {
"type": "boolean"
},
"courses": {
"type": "array",
"items": {
"type": "string"
}
},
"name": {"type": "string"},
"age": {"type": "number"},
"is_student": {"type": "boolean"},
"courses": {"type": "array", "items": {"type": "string"}},
},
}
name = "Test Agent"
description = "This is a test agent"
task = ("Generate a person's information based on the following"
" schema:")
task = (
"Generate a person's information based on the following"
" schema:"
)
agent = ToolAgent(name, description, model, tokenizer, json_schema)
agent = ToolAgent(
name, description, model, tokenizer, json_schema
)
agent.run(task)
mock_run.assert_called_once_with(task)
@ -81,21 +65,10 @@ def test_tool_agent_init_with_kwargs():
json_schema = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "number"
},
"is_student": {
"type": "boolean"
},
"courses": {
"type": "array",
"items": {
"type": "string"
}
},
"name": {"type": "string"},
"age": {"type": "number"},
"is_student": {"type": "boolean"},
"courses": {"type": "array", "items": {"type": "string"}},
},
}
name = "Test Agent"
@ -109,8 +82,9 @@ def test_tool_agent_init_with_kwargs():
"max_string_token_length": 20,
}
agent = ToolAgent(name, description, model, tokenizer, json_schema,
**kwargs)
agent = ToolAgent(
name, description, model, tokenizer, json_schema, **kwargs
)
assert agent.name == name
assert agent.description == description
@ -121,4 +95,7 @@ def test_tool_agent_init_with_kwargs():
assert agent.max_array_length == kwargs["max_array_length"]
assert agent.max_number_tokens == kwargs["max_number_tokens"]
assert agent.temperature == kwargs["temperature"]
assert (agent.max_string_token_length == kwargs["max_string_token_length"])
assert (
agent.max_string_token_length
== kwargs["max_string_token_length"]
)

@ -33,8 +33,9 @@ def test_memory_limit_enforced(memory):
# Parameterized Tests
@pytest.mark.parametrize("scores, best_score", [([10, 5, 3], 10),
([1, 2, 3], 3)])
@pytest.mark.parametrize(
"scores, best_score", [([10, 5, 3], 10), ([1, 2, 3], 3)]
)
def test_get_top_n(scores, best_score, memory):
for score in scores:
memory.add(score, {"data": f"test{score}"})

@ -26,7 +26,8 @@ def memory_instance(memory_file):
def test_init(memory_file):
memory = DictSharedMemory(file_loc=memory_file)
assert os.path.exists(
memory.file_loc), "Memory file should be created if non-existent"
memory.file_loc
), "Memory file should be created if non-existent"
def test_add_entry(memory_instance):
@ -43,8 +44,9 @@ def test_get_top_n(memory_instance):
memory_instance.add(9.5, "agent123", 1, "Entry A")
memory_instance.add(8.5, "agent124", 1, "Entry B")
top_1 = memory_instance.get_top_n(1)
assert (len(top_1) == 1
), "get_top_n should return the correct number of top entries"
assert (
len(top_1) == 1
), "get_top_n should return the correct number of top entries"
# Parameterized tests
@ -57,14 +59,18 @@ def test_get_top_n(memory_instance):
# add more test cases
],
)
def test_parametrized_get_top_n(memory_instance, scores, agent_ids,
expected_top_score):
def test_parametrized_get_top_n(
memory_instance, scores, agent_ids, expected_top_score
):
for score, agent_id in zip(scores, agent_ids):
memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}")
memory_instance.add(
score, agent_id, 1, f"Entry by {agent_id}"
)
top_1 = memory_instance.get_top_n(1)
top_score = next(iter(top_1.values()))["score"]
assert (top_score == expected_top_score
), "get_top_n should return the entry with top score"
assert (
top_score == expected_top_score
), "get_top_n should return the entry with top score"
# Exception testing
@ -72,7 +78,9 @@ def test_parametrized_get_top_n(memory_instance, scores, agent_ids,
def test_add_entry_invalid_input(memory_instance):
with pytest.raises(ValueError):
memory_instance.add("invalid_score", "agent123", 1, "Test Entry")
memory_instance.add(
"invalid_score", "agent123", 1, "Test Entry"
)
# Mocks and monkey-patching

@ -35,13 +35,16 @@ def qa_mock():
# Example test cases
def test_initialization_default_settings(vector_memory):
assert vector_memory.chunk_size == 1000
assert (vector_memory.chunk_overlap == 100
) # assuming default overlap of 0.1
assert (
vector_memory.chunk_overlap == 100
) # assuming default overlap of 0.1
assert vector_memory.loc.exists()
def test_add_entry(vector_memory, embeddings_mock):
with patch.object(vector_memory.db, "add_texts") as add_texts_mock:
with patch.object(
vector_memory.db, "add_texts"
) as add_texts_mock:
vector_memory.add("Example text")
add_texts_mock.assert_called()
@ -74,17 +77,20 @@ def test_ask_question_returns_string(vector_memory, qa_mock):
), # Mocked object as a placeholder
],
)
def test_search_memory_different_params(vector_memory, query, k, type,
expected):
def test_search_memory_different_params(
vector_memory, query, k, type, expected
):
with patch.object(
vector_memory.db,
"max_marginal_relevance_search",
return_value=expected,
vector_memory.db,
"max_marginal_relevance_search",
return_value=expected,
):
with patch.object(
vector_memory.db,
"similarity_search_with_score",
return_value=expected,
vector_memory.db,
"similarity_search_with_score",
return_value=expected,
):
result = vector_memory.search_memory(query, k=k, type=type)
result = vector_memory.search_memory(
query, k=k, type=type
)
assert len(result) == (k if k > 0 else 0)

@ -8,7 +8,8 @@ api_key = os.getenv("PINECONE_API_KEY") or ""
def test_init():
with patch("pinecone.init") as MockInit, patch(
"pinecone.Index") as MockIndex:
"pinecone.Index"
) as MockIndex:
store = PineconeDB(
api_key=api_key,
index_name="test_index",
@ -70,7 +71,8 @@ def test_query():
def test_create_index():
with patch("pinecone.init"), patch("pinecone.Index"), patch(
"pinecone.create_index") as MockCreateIndex:
"pinecone.create_index"
) as MockCreateIndex:
store = PineconeDB(
api_key=api_key,
index_name="test_index",

@ -32,7 +32,8 @@ def test_create_vector_model():
def test_add_or_update_vector():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session") as MockSession:
"sqlalchemy.orm.Session"
) as MockSession:
db = PostgresDB(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
@ -50,7 +51,8 @@ def test_add_or_update_vector():
def test_query_vectors():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session") as MockSession:
"sqlalchemy.orm.Session"
) as MockSession:
db = PostgresDB(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
@ -65,7 +67,8 @@ def test_query_vectors():
def test_delete_vector():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session") as MockSession:
"sqlalchemy.orm.Session"
) as MockSession:
db = PostgresDB(
connection_string=PSG_CONNECTION_STRING,
table_name="test",

@ -13,8 +13,9 @@ def mock_qdrant_client():
@pytest.fixture
def mock_sentence_transformer():
with patch("sentence_transformers.SentenceTransformer"
) as MockSentenceTransformer:
with patch(
"sentence_transformers.SentenceTransformer"
) as MockSentenceTransformer:
yield MockSentenceTransformer()
@ -28,7 +29,9 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client):
assert qdrant_client.client is not None
def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
def test_load_embedding_model(
qdrant_client, mock_sentence_transformer
):
qdrant_client._load_embedding_model("model_name")
mock_sentence_transformer.assert_called_once_with("model_name")
@ -36,7 +39,8 @@ def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
def test_setup_collection(qdrant_client, mock_qdrant_client):
qdrant_client._setup_collection()
mock_qdrant_client.get_collection.assert_called_once_with(
qdrant_client.collection_name)
qdrant_client.collection_name
)
def test_add_vectors(qdrant_client, mock_qdrant_client):

@ -12,29 +12,26 @@ def test_init():
def test_add():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
assert memory.short_term_memory == [{
"role": "user",
"message": "Hello, world!"
}]
assert memory.short_term_memory == [
{"role": "user", "message": "Hello, world!"}
]
def test_get_short_term():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
assert memory.get_short_term() == [{
"role": "user",
"message": "Hello, world!"
}]
assert memory.get_short_term() == [
{"role": "user", "message": "Hello, world!"}
]
def test_get_medium_term():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
memory.move_to_medium_term(0)
assert memory.get_medium_term() == [{
"role": "user",
"message": "Hello, world!"
}]
assert memory.get_medium_term() == [
{"role": "user", "message": "Hello, world!"}
]
def test_clear_medium_term():
@ -48,18 +45,19 @@ def test_clear_medium_term():
def test_get_short_term_memory_str():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
assert (memory.get_short_term_memory_str() ==
"[{'role': 'user', 'message': 'Hello, world!'}]")
assert (
memory.get_short_term_memory_str()
== "[{'role': 'user', 'message': 'Hello, world!'}]"
)
def test_update_short_term():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
memory.update_short_term(0, "user", "Goodbye, world!")
assert memory.get_short_term() == [{
"role": "user",
"message": "Goodbye, world!"
}]
assert memory.get_short_term() == [
{"role": "user", "message": "Goodbye, world!"}
]
def test_clear():
@ -73,10 +71,9 @@ def test_search_memory():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
assert memory.search_memory("Hello") == {
"short_term": [(0, {
"role": "user",
"message": "Hello, world!"
})],
"short_term": [
(0, {"role": "user", "message": "Hello, world!"})
],
"medium_term": [],
}
@ -84,18 +81,19 @@ def test_search_memory():
def test_return_shortmemory_as_str():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
assert (memory.return_shortmemory_as_str() ==
"[{'role': 'user', 'message': 'Hello, world!'}]")
assert (
memory.return_shortmemory_as_str()
== "[{'role': 'user', 'message': 'Hello, world!'}]"
)
def test_move_to_medium_term():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
memory.move_to_medium_term(0)
assert memory.get_medium_term() == [{
"role": "user",
"message": "Hello, world!"
}]
assert memory.get_medium_term() == [
{"role": "user", "message": "Hello, world!"}
]
assert memory.get_short_term() == []
@ -103,8 +101,10 @@ def test_return_medium_memory_as_str():
memory = ShortTermMemory()
memory.add("user", "Hello, world!")
memory.move_to_medium_term(0)
assert (memory.return_medium_memory_as_str() ==
"[{'role': 'user', 'message': 'Hello, world!'}]")
assert (
memory.return_medium_memory_as_str()
== "[{'role': 'user', 'message': 'Hello, world!'}]"
)
def test_thread_safety():
@ -114,7 +114,9 @@ def test_thread_safety():
for _ in range(1000):
memory.add("user", "Hello, world!")
threads = [threading.Thread(target=add_messages) for _ in range(10)]
threads = [
threading.Thread(target=add_messages) for _ in range(10)
]
for thread in threads:
thread.start()
for thread in threads:

@ -8,7 +8,9 @@ from swarms.memory.sqlite import SQLiteDB
@pytest.fixture
def db():
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute(
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
)
conn.commit()
return SQLiteDB(":memory:")
@ -28,7 +30,9 @@ def test_delete(db):
def test_update(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
db.update("UPDATE test SET name = ? WHERE name = ?", ("new", "test"))
db.update(
"UPDATE test SET name = ? WHERE name = ?", ("new", "test")
)
result = db.query("SELECT * FROM test")
assert result == [(1, "new")]
@ -41,7 +45,9 @@ def test_query(db):
def test_execute_query(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.execute_query("SELECT * FROM test WHERE name = ?", ("test",))
result = db.execute_query(
"SELECT * FROM test WHERE name = ?", ("test",)
)
assert result == [(1, "test")]
@ -95,4 +101,6 @@ def test_query_with_wrong_query(db):
def test_execute_query_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",))
db.execute_query(
"SELECT * FROM wrong WHERE name = ?", ("test",)
)

@ -16,7 +16,9 @@ def weaviate_client_mock():
grpc_port="mock_grpc_port",
grpc_secure=False,
auth_client_secret="mock_api_key",
additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"},
additional_headers={
"X-OpenAI-Api-Key": "mock_openai_api_key"
},
additional_config=Mock(),
)
@ -34,15 +36,13 @@ def weaviate_client_mock():
# Define tests for the WeaviateDB class
def test_create_collection(weaviate_client_mock):
# Test creating a collection
weaviate_client_mock.create_collection("test_collection", [{
"name": "property"
}])
weaviate_client_mock.create_collection(
"test_collection", [{"name": "property"}]
)
weaviate_client_mock.client.collections.create.assert_called_with(
name="test_collection",
vectorizer_config=None,
properties=[{
"name": "property"
}],
properties=[{"name": "property"}],
)
@ -51,9 +51,11 @@ def test_add_object(weaviate_client_mock):
properties = {"name": "John"}
weaviate_client_mock.add("test_collection", properties)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.data.insert.assert_called_with(
properties)
properties
)
def test_query_objects(weaviate_client_mock):
@ -61,20 +63,26 @@ def test_query_objects(weaviate_client_mock):
query = "name:John"
weaviate_client_mock.query("test_collection", query)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.query.bm25.assert_called_with(
query=query, limit=10)
query=query, limit=10
)
def test_update_object(weaviate_client_mock):
# Test updating an object
object_id = "12345"
properties = {"name": "Jane"}
weaviate_client_mock.update("test_collection", object_id, properties)
weaviate_client_mock.update(
"test_collection", object_id, properties
)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.data.update.assert_called_with(
object_id, properties)
object_id, properties
)
def test_delete_object(weaviate_client_mock):
@ -82,23 +90,25 @@ def test_delete_object(weaviate_client_mock):
object_id = "12345"
weaviate_client_mock.delete("test_collection", object_id)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with(
object_id)
object_id
)
def test_create_collection_with_vectorizer_config(weaviate_client_mock,):
def test_create_collection_with_vectorizer_config(
weaviate_client_mock,
):
# Test creating a collection with vectorizer configuration
vectorizer_config = {"config_key": "config_value"}
weaviate_client_mock.create_collection("test_collection", [{
"name": "property"
}], vectorizer_config)
weaviate_client_mock.create_collection(
"test_collection", [{"name": "property"}], vectorizer_config
)
weaviate_client_mock.client.collections.create.assert_called_with(
name="test_collection",
vectorizer_config=vectorizer_config,
properties=[{
"name": "property"
}],
properties=[{"name": "property"}],
)
@ -108,9 +118,11 @@ def test_query_objects_with_limit(weaviate_client_mock):
limit = 20
weaviate_client_mock.query("test_collection", query, limit)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.query.bm25.assert_called_with(
query=query, limit=limit)
query=query, limit=limit
)
def test_query_objects_without_limit(weaviate_client_mock):
@ -118,29 +130,33 @@ def test_query_objects_without_limit(weaviate_client_mock):
query = "name:John"
weaviate_client_mock.query("test_collection", query)
weaviate_client_mock.client.collections.get.assert_called_with(
"test_collection")
"test_collection"
)
weaviate_client_mock.client.collections.query.bm25.assert_called_with(
query=query, limit=10)
query=query, limit=10
)
def test_create_collection_failure(weaviate_client_mock):
# Test failure when creating a collection
with patch(
"weaviate_client.weaviate.collections.create",
side_effect=Exception("Create error"),
"weaviate_client.weaviate.collections.create",
side_effect=Exception("Create error"),
):
with pytest.raises(Exception, match="Error creating collection"):
weaviate_client_mock.create_collection("test_collection", [{
"name": "property"
}])
with pytest.raises(
Exception, match="Error creating collection"
):
weaviate_client_mock.create_collection(
"test_collection", [{"name": "property"}]
)
def test_add_object_failure(weaviate_client_mock):
# Test failure when adding an object
properties = {"name": "John"}
with patch(
"weaviate_client.weaviate.collections.data.insert",
side_effect=Exception("Insert error"),
"weaviate_client.weaviate.collections.data.insert",
side_effect=Exception("Insert error"),
):
with pytest.raises(Exception, match="Error adding object"):
weaviate_client_mock.add("test_collection", properties)
@ -150,8 +166,8 @@ def test_query_objects_failure(weaviate_client_mock):
# Test failure when querying objects
query = "name:John"
with patch(
"weaviate_client.weaviate.collections.query.bm25",
side_effect=Exception("Query error"),
"weaviate_client.weaviate.collections.query.bm25",
side_effect=Exception("Query error"),
):
with pytest.raises(Exception, match="Error querying objects"):
weaviate_client_mock.query("test_collection", query)
@ -162,20 +178,21 @@ def test_update_object_failure(weaviate_client_mock):
object_id = "12345"
properties = {"name": "Jane"}
with patch(
"weaviate_client.weaviate.collections.data.update",
side_effect=Exception("Update error"),
"weaviate_client.weaviate.collections.data.update",
side_effect=Exception("Update error"),
):
with pytest.raises(Exception, match="Error updating object"):
weaviate_client_mock.update("test_collection", object_id,
properties)
weaviate_client_mock.update(
"test_collection", object_id, properties
)
def test_delete_object_failure(weaviate_client_mock):
# Test failure when deleting an object
object_id = "12345"
with patch(
"weaviate_client.weaviate.collections.data.delete_by_id",
side_effect=Exception("Delete error"),
"weaviate_client.weaviate.collections.data.delete_by_id",
side_effect=Exception("Delete error"),
):
with pytest.raises(Exception, match="Error deleting object"):
weaviate_client_mock.delete("test_collection", object_id)

@ -8,11 +8,12 @@ from swarms.models.anthropic import Anthropic
# Mock the Anthropic API client for testing
class MockAnthropicClient:
def __init__(self, *args, **kwargs):
pass
def completions_create(self, prompt, stop_sequences, stream, **kwargs):
def completions_create(
self, prompt, stop_sequences, stream, **kwargs
):
return MockAnthropicResponse()
@ -45,7 +46,9 @@ def test_anthropic_init_default_values(anthropic_instance):
assert anthropic_instance.streaming is False
assert anthropic_instance.default_request_timeout == 600
assert (
anthropic_instance.anthropic_api_url == "https://test.anthropic.com")
anthropic_instance.anthropic_api_url
== "https://test.anthropic.com"
)
assert anthropic_instance.anthropic_api_key == "test_api_key"
@ -76,8 +79,9 @@ def test_anthropic_default_params(anthropic_instance):
}
def test_anthropic_run(mock_anthropic_env, mock_requests_post,
anthropic_instance):
def test_anthropic_run(
mock_anthropic_env, mock_requests_post, anthropic_instance
):
mock_response = Mock()
mock_response.json.return_value = {"completion": "Generated text"}
mock_requests_post.return_value = mock_response
@ -101,8 +105,9 @@ def test_anthropic_run(mock_anthropic_env, mock_requests_post,
)
def test_anthropic_call(mock_anthropic_env, mock_requests_post,
anthropic_instance):
def test_anthropic_call(
mock_anthropic_env, mock_requests_post, anthropic_instance
):
mock_response = Mock()
mock_response.json.return_value = {"completion": "Generated text"}
mock_requests_post.return_value = mock_response
@ -126,8 +131,9 @@ def test_anthropic_call(mock_anthropic_env, mock_requests_post,
)
def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post,
anthropic_instance):
def test_anthropic_exception_handling(
mock_anthropic_env, mock_requests_post, anthropic_instance
):
mock_response = Mock()
mock_response.json.return_value = {"error": "An error occurred"}
mock_requests_post.return_value = mock_response
@ -142,7 +148,6 @@ def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post,
class MockAnthropicResponse:
def __init__(self):
self.completion = "Mocked Response from Anthropic"
@ -168,7 +173,9 @@ def test_anthropic_async_call_method(anthropic_instance):
def test_anthropic_async_stream_method(anthropic_instance):
async_generator = anthropic_instance.async_stream("Translate to French.")
async_generator = anthropic_instance.async_stream(
"Translate to French."
)
for token in async_generator:
assert isinstance(token, str)
@ -192,51 +199,63 @@ def test_anthropic_wrap_prompt(anthropic_instance):
def test_anthropic_convert_prompt(anthropic_instance):
prompt = "What is the meaning of life?"
converted_prompt = anthropic_instance.convert_prompt(prompt)
assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT)
assert converted_prompt.startswith(
anthropic_instance.HUMAN_PROMPT
)
assert converted_prompt.endswith(anthropic_instance.AI_PROMPT)
def test_anthropic_call_with_stop(anthropic_instance):
response = anthropic_instance("Translate to French.",
stop=["stop1", "stop2"])
response = anthropic_instance(
"Translate to French.", stop=["stop1", "stop2"]
)
assert response == "Mocked Response from Anthropic"
def test_anthropic_stream_with_stop(anthropic_instance):
generator = anthropic_instance.stream("Write a story.",
stop=["stop1", "stop2"])
generator = anthropic_instance.stream(
"Write a story.", stop=["stop1", "stop2"]
)
for token in generator:
assert isinstance(token, str)
def test_anthropic_async_call_with_stop(anthropic_instance):
response = anthropic_instance.async_call("Tell me a joke.",
stop=["stop1", "stop2"])
response = anthropic_instance.async_call(
"Tell me a joke.", stop=["stop1", "stop2"]
)
assert response == "Mocked Response from Anthropic"
def test_anthropic_async_stream_with_stop(anthropic_instance):
async_generator = anthropic_instance.async_stream("Translate to French.",
stop=["stop1", "stop2"])
async_generator = anthropic_instance.async_stream(
"Translate to French.", stop=["stop1", "stop2"]
)
for token in async_generator:
assert isinstance(token, str)
def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance,):
def test_anthropic_get_num_tokens_with_count_tokens(
anthropic_instance,
):
anthropic_instance.count_tokens = Mock(return_value=10)
text = "This is a test sentence."
num_tokens = anthropic_instance.get_num_tokens(text)
assert num_tokens == 10
def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance,):
def test_anthropic_get_num_tokens_without_count_tokens(
anthropic_instance,
):
del anthropic_instance.count_tokens
with pytest.raises(NameError):
text = "This is a test sentence."
anthropic_instance.get_num_tokens(text)
def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance,):
def test_anthropic_wrap_prompt_without_human_ai_prompt(
anthropic_instance,
):
del anthropic_instance.HUMAN_PROMPT
del anthropic_instance.AI_PROMPT
prompt = "What is the meaning of life?"

@ -48,8 +48,10 @@ def test_cell_biology_response(biogpt_instance):
# 40. Test for a question about protein structure
def test_protein_structure_response(biogpt_instance):
question = ("What's the difference between alpha helix and beta sheet"
" structures in proteins?")
question = (
"What's the difference between alpha helix and beta sheet"
" structures in proteins?"
)
response = biogpt_instance(question)
assert response
assert isinstance(response, str)
@ -81,7 +83,9 @@ def test_bioinformatics_response(biogpt_instance):
# 44. Test for a neuroscience question
def test_neuroscience_response(biogpt_instance):
question = ("Explain the function of synapses in the nervous system.")
question = (
"Explain the function of synapses in the nervous system."
)
response = biogpt_instance(question)
assert response
assert isinstance(response, str)
@ -104,11 +108,8 @@ def test_init(bio_gpt):
def test_call(bio_gpt, monkeypatch):
def mock_pipeline(*args, **kwargs):
class MockGenerator:
def __call__(self, text, **kwargs):
return ["Generated text"]
@ -166,7 +167,9 @@ def test_get_config_return_type(biogpt_instance):
# 28. Test saving model functionality by checking if files are created
@patch.object(BioGptForCausalLM, "save_pretrained")
@patch.object(BioGptTokenizer, "save_pretrained")
def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
def test_save_model(
mock_save_model, mock_save_tokenizer, biogpt_instance
):
path = "test_path"
biogpt_instance.save_model(path)
mock_save_model.assert_called_once_with(path)
@ -176,7 +179,9 @@ def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
# 29. Test loading model from path
@patch.object(BioGptForCausalLM, "from_pretrained")
@patch.object(BioGptTokenizer, "from_pretrained")
def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance):
def test_load_from_path(
mock_load_model, mock_load_tokenizer, biogpt_instance
):
path = "test_path"
biogpt_instance.load_from_path(path)
mock_load_model.assert_called_once_with(path)
@ -193,7 +198,9 @@ def test_print_model_metadata(biogpt_instance):
# 31. Test that beam_search_decoding uses the correct number of beams
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
def test_beam_search_decoding_num_beams(
mock_generate, biogpt_instance
):
biogpt_instance.beam_search_decoding("test_sentence", num_beams=7)
_, kwargs = mock_generate.call_args
assert kwargs["num_beams"] == 7
@ -201,8 +208,12 @@ def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
# 32. Test if beam_search_decoding handles early_stopping
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance):
biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False)
def test_beam_search_decoding_early_stopping(
mock_generate, biogpt_instance
):
biogpt_instance.beam_search_decoding(
"test_sentence", early_stopping=False
)
_, kwargs = mock_generate.call_args
assert kwargs["early_stopping"] is False

@ -42,7 +42,9 @@ def test_cohere_async_api_error_handling(cohere_instance):
cohere_instance.model = "base"
cohere_instance.cohere_api_key = "invalid-api-key"
with pytest.raises(Exception):
cohere_instance.async_call("Error handling with invalid API key.")
cohere_instance.async_call(
"Error handling with invalid API key."
)
def test_cohere_stream_api_error_handling(cohere_instance):
@ -51,7 +53,8 @@ def test_cohere_stream_api_error_handling(cohere_instance):
cohere_instance.cohere_api_key = "invalid-api-key"
with pytest.raises(Exception):
generator = cohere_instance.stream(
"Error handling with invalid API key.")
"Error handling with invalid API key."
)
for token in generator:
pass
@ -91,26 +94,31 @@ def test_cohere_convert_prompt(cohere_instance):
def test_cohere_call_with_stop(cohere_instance):
response = cohere_instance("Translate to French.", stop=["stop1", "stop2"])
response = cohere_instance(
"Translate to French.", stop=["stop1", "stop2"]
)
assert response == "Mocked Response from Cohere"
def test_cohere_stream_with_stop(cohere_instance):
generator = cohere_instance.stream("Write a story.",
stop=["stop1", "stop2"])
generator = cohere_instance.stream(
"Write a story.", stop=["stop1", "stop2"]
)
for token in generator:
assert isinstance(token, str)
def test_cohere_async_call_with_stop(cohere_instance):
response = cohere_instance.async_call("Tell me a joke.",
stop=["stop1", "stop2"])
response = cohere_instance.async_call(
"Tell me a joke.", stop=["stop1", "stop2"]
)
assert response == "Mocked Response from Cohere"
def test_cohere_async_stream_with_stop(cohere_instance):
async_generator = cohere_instance.async_stream("Translate to French.",
stop=["stop1", "stop2"])
async_generator = cohere_instance.async_stream(
"Translate to French.", stop=["stop1", "stop2"]
)
for token in async_generator:
assert isinstance(token, str)
@ -166,8 +174,12 @@ def test_base_cohere_validate_environment_without_cohere():
# Test cases for benchmarking generations with various models
def test_cohere_generate_with_command_light(cohere_instance):
cohere_instance.model = "command-light"
response = cohere_instance("Generate text with Command Light model.")
assert response.startswith("Generated text with Command Light model")
response = cohere_instance(
"Generate text with Command Light model."
)
assert response.startswith(
"Generated text with Command Light model"
)
def test_cohere_generate_with_command(cohere_instance):
@ -190,54 +202,74 @@ def test_cohere_generate_with_base(cohere_instance):
def test_cohere_generate_with_embed_english_v2(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
response = cohere_instance("Generate embeddings with English v2.0 model.")
assert response.startswith("Generated embeddings with English v2.0 model")
response = cohere_instance(
"Generate embeddings with English v2.0 model."
)
assert response.startswith(
"Generated embeddings with English v2.0 model"
)
def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
cohere_instance.model = "embed-english-light-v2.0"
response = cohere_instance(
"Generate embeddings with English Light v2.0 model.")
"Generate embeddings with English Light v2.0 model."
)
assert response.startswith(
"Generated embeddings with English Light v2.0 model")
"Generated embeddings with English Light v2.0 model"
)
def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance(
"Generate embeddings with Multilingual v2.0 model.")
"Generate embeddings with Multilingual v2.0 model."
)
assert response.startswith(
"Generated embeddings with Multilingual v2.0 model")
"Generated embeddings with Multilingual v2.0 model"
)
def test_cohere_generate_with_embed_english_v3(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
response = cohere_instance("Generate embeddings with English v3.0 model.")
assert response.startswith("Generated embeddings with English v3.0 model")
response = cohere_instance(
"Generate embeddings with English v3.0 model."
)
assert response.startswith(
"Generated embeddings with English v3.0 model"
)
def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
cohere_instance.model = "embed-english-light-v3.0"
response = cohere_instance(
"Generate embeddings with English Light v3.0 model.")
"Generate embeddings with English Light v3.0 model."
)
assert response.startswith(
"Generated embeddings with English Light v3.0 model")
"Generated embeddings with English Light v3.0 model"
)
def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance(
"Generate embeddings with Multilingual v3.0 model.")
"Generate embeddings with Multilingual v3.0 model."
)
assert response.startswith(
"Generated embeddings with Multilingual v3.0 model")
"Generated embeddings with Multilingual v3.0 model"
)
def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance,):
def test_cohere_generate_with_embed_multilingual_light_v3(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-light-v3.0"
response = cohere_instance(
"Generate embeddings with Multilingual Light v3.0 model.")
"Generate embeddings with Multilingual Light v3.0 model."
)
assert response.startswith(
"Generated embeddings with Multilingual Light v3.0 model")
"Generated embeddings with Multilingual Light v3.0 model"
)
# Add more test cases to benchmark other models and functionalities
@ -267,13 +299,17 @@ def test_cohere_call_with_embed_english_v3_model(cohere_instance):
assert isinstance(response, str)
def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance,):
def test_cohere_call_with_embed_multilingual_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance,):
def test_cohere_call_with_embed_multilingual_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
@ -293,7 +329,9 @@ def test_cohere_call_with_long_prompt(cohere_instance):
def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit.")
prompt = (
"This is a test prompt that will exceed the max tokens limit."
)
with pytest.raises(ValueError):
cohere_instance(prompt)
@ -326,14 +364,18 @@ def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
assert isinstance(token, str)
def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance,):
def test_cohere_stream_with_embed_multilingual_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v2.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance,):
def test_cohere_stream_with_embed_multilingual_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v3.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
@ -352,25 +394,33 @@ def test_cohere_async_call_with_base_model(cohere_instance):
assert isinstance(response, str)
def test_cohere_async_call_with_embed_english_v2_model(cohere_instance,):
def test_cohere_async_call_with_embed_english_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-english-v2.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_english_v3_model(cohere_instance,):
def test_cohere_async_call_with_embed_english_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-english-v3.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance,):
def test_cohere_async_call_with_embed_multilingual_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance,):
def test_cohere_async_call_with_embed_multilingual_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
@ -390,28 +440,36 @@ def test_cohere_async_stream_with_base_model(cohere_instance):
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance,):
def test_cohere_async_stream_with_embed_english_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-english-v2.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance,):
def test_cohere_async_stream_with_embed_english_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-english-v3.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance,):
def test_cohere_async_stream_with_embed_multilingual_v2_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v2.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,):
def test_cohere_async_stream_with_embed_multilingual_v3_model(
cohere_instance,
):
cohere_instance.model = "embed-multilingual-v3.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
@ -421,7 +479,9 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance,):
def test_cohere_representation_model_embedding(cohere_instance):
# Test using the Representation model for text embedding
cohere_instance.model = "embed-english-v3.0"
embedding = cohere_instance.embed("Generate an embedding for this text.")
embedding = cohere_instance.embed(
"Generate an embedding for this text."
)
assert isinstance(embedding, list)
assert len(embedding) > 0
@ -435,20 +495,26 @@ def test_cohere_representation_model_classification(cohere_instance):
assert "score" in classification
def test_cohere_representation_model_language_detection(cohere_instance,):
def test_cohere_representation_model_language_detection(
cohere_instance,
):
# Test using the Representation model for language detection
cohere_instance.model = "embed-english-v3.0"
language = cohere_instance.detect_language(
"Detect the language of this text.")
"Detect the language of this text."
)
assert isinstance(language, str)
def test_cohere_representation_model_max_tokens_limit_exceeded(
cohere_instance,):
cohere_instance,
):
# Test handling max tokens limit exceeded error
cohere_instance.model = "embed-english-v3.0"
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit.")
prompt = (
"This is a test prompt that will exceed the max tokens limit."
)
with pytest.raises(ValueError):
cohere_instance.embed(prompt)
@ -456,80 +522,102 @@ def test_cohere_representation_model_max_tokens_limit_exceeded(
# Add more production-grade test cases based on real-world scenarios
def test_cohere_representation_model_multilingual_embedding(cohere_instance,):
def test_cohere_representation_model_multilingual_embedding(
cohere_instance,
):
# Test using the Representation model for multilingual text embedding
cohere_instance.model = "embed-multilingual-v3.0"
embedding = cohere_instance.embed("Generate multilingual embeddings.")
embedding = cohere_instance.embed(
"Generate multilingual embeddings."
)
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_multilingual_classification(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for multilingual text classification
cohere_instance.model = "embed-multilingual-v3.0"
classification = cohere_instance.classify("Classify multilingual text.")
classification = cohere_instance.classify(
"Classify multilingual text."
)
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_multilingual_language_detection(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for multilingual language detection
cohere_instance.model = "embed-multilingual-v3.0"
language = cohere_instance.detect_language(
"Detect the language of multilingual text.")
"Detect the language of multilingual text."
)
assert isinstance(language, str)
def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
cohere_instance,):
cohere_instance,
):
# Test handling max tokens limit exceeded error for multilingual model
cohere_instance.model = "embed-multilingual-v3.0"
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit"
" for multilingual model.")
prompt = (
"This is a test prompt that will exceed the max tokens limit"
" for multilingual model."
)
with pytest.raises(ValueError):
cohere_instance.embed(prompt)
def test_cohere_representation_model_multilingual_light_embedding(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for multilingual light text embedding
cohere_instance.model = "embed-multilingual-light-v3.0"
embedding = cohere_instance.embed("Generate multilingual light embeddings.")
embedding = cohere_instance.embed(
"Generate multilingual light embeddings."
)
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_multilingual_light_classification(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for multilingual light text classification
cohere_instance.model = "embed-multilingual-light-v3.0"
classification = cohere_instance.classify(
"Classify multilingual light text.")
"Classify multilingual light text."
)
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_multilingual_light_language_detection(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for multilingual light language detection
cohere_instance.model = "embed-multilingual-light-v3.0"
language = cohere_instance.detect_language(
"Detect the language of multilingual light text.")
"Detect the language of multilingual light text."
)
assert isinstance(language, str)
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
cohere_instance,):
cohere_instance,
):
# Test handling max tokens limit exceeded error for multilingual light model
cohere_instance.model = "embed-multilingual-light-v3.0"
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit"
" for multilingual light model.")
prompt = (
"This is a test prompt that will exceed the max tokens limit"
" for multilingual light model."
)
with pytest.raises(ValueError):
cohere_instance.embed(prompt)
@ -537,14 +625,18 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede
def test_cohere_command_light_model(cohere_instance):
# Test using the Command Light model for text generation
cohere_instance.model = "command-light"
response = cohere_instance("Generate text using Command Light model.")
response = cohere_instance(
"Generate text using Command Light model."
)
assert isinstance(response, str)
def test_cohere_base_light_model(cohere_instance):
# Test using the Base Light model for text generation
cohere_instance.model = "base-light"
response = cohere_instance("Generate text using Base Light model.")
response = cohere_instance(
"Generate text using Base Light model."
)
assert isinstance(response, str)
@ -555,7 +647,9 @@ def test_cohere_generate_summarize_endpoint(cohere_instance):
assert isinstance(response, str)
def test_cohere_representation_model_english_embedding(cohere_instance,):
def test_cohere_representation_model_english_embedding(
cohere_instance,
):
# Test using the Representation model for English text embedding
cohere_instance.model = "embed-english-v3.0"
embedding = cohere_instance.embed("Generate English embeddings.")
@ -563,69 +657,90 @@ def test_cohere_representation_model_english_embedding(cohere_instance,):
assert len(embedding) > 0
def test_cohere_representation_model_english_classification(cohere_instance,):
def test_cohere_representation_model_english_classification(
cohere_instance,
):
# Test using the Representation model for English text classification
cohere_instance.model = "embed-english-v3.0"
classification = cohere_instance.classify("Classify English text.")
classification = cohere_instance.classify(
"Classify English text."
)
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_english_language_detection(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for English language detection
cohere_instance.model = "embed-english-v3.0"
language = cohere_instance.detect_language(
"Detect the language of English text.")
"Detect the language of English text."
)
assert isinstance(language, str)
def test_cohere_representation_model_english_max_tokens_limit_exceeded(
cohere_instance,):
cohere_instance,
):
# Test handling max tokens limit exceeded error for English model
cohere_instance.model = "embed-english-v3.0"
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit"
" for English model.")
prompt = (
"This is a test prompt that will exceed the max tokens limit"
" for English model."
)
with pytest.raises(ValueError):
cohere_instance.embed(prompt)
def test_cohere_representation_model_english_light_embedding(cohere_instance,):
def test_cohere_representation_model_english_light_embedding(
cohere_instance,
):
# Test using the Representation model for English light text embedding
cohere_instance.model = "embed-english-light-v3.0"
embedding = cohere_instance.embed("Generate English light embeddings.")
embedding = cohere_instance.embed(
"Generate English light embeddings."
)
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_english_light_classification(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for English light text classification
cohere_instance.model = "embed-english-light-v3.0"
classification = cohere_instance.classify("Classify English light text.")
classification = cohere_instance.classify(
"Classify English light text."
)
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_english_light_language_detection(
cohere_instance,):
cohere_instance,
):
# Test using the Representation model for English light language detection
cohere_instance.model = "embed-english-light-v3.0"
language = cohere_instance.detect_language(
"Detect the language of English light text.")
"Detect the language of English light text."
)
assert isinstance(language, str)
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
cohere_instance,):
cohere_instance,
):
# Test handling max tokens limit exceeded error for English light model
cohere_instance.model = "embed-english-light-v3.0"
cohere_instance.max_tokens = 10
prompt = ("This is a test prompt that will exceed the max tokens limit"
" for English light model.")
prompt = (
"This is a test prompt that will exceed the max tokens limit"
" for English light model."
)
with pytest.raises(ValueError):
cohere_instance.embed(prompt)
@ -633,7 +748,9 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
def test_cohere_command_model(cohere_instance):
# Test using the Command model for text generation
cohere_instance.model = "command"
response = cohere_instance("Generate text using the Command model.")
response = cohere_instance(
"Generate text using the Command model."
)
assert isinstance(response, str)
@ -647,7 +764,9 @@ def test_cohere_invalid_model(cohere_instance):
cohere_instance("Generate text using an invalid model.")
def test_cohere_base_model_generation_with_max_tokens(cohere_instance,):
def test_cohere_base_model_generation_with_max_tokens(
cohere_instance,
):
# Test generating text using the base model with a specified max_tokens limit
cohere_instance.model = "base"
cohere_instance.max_tokens = 20

@ -30,30 +30,45 @@ def test_run_text_to_speech(eleven_labs_tool):
def test_play_speech(eleven_labs_tool):
with patch("builtins.open", mock_open(read_data="fake_audio_data")):
with patch(
"builtins.open", mock_open(read_data="fake_audio_data")
):
eleven_labs_tool.play(EXPECTED_SPEECH_FILE)
def test_stream_speech(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
with patch(
"tempfile.NamedTemporaryFile", mock_open()
) as mock_file:
eleven_labs_tool.stream_speech(SAMPLE_TEXT)
mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False)
mock_file.assert_called_with(
mode="bx", suffix=".wav", delete=False
)
# Testing fixture and environment variables
def test_api_key_validation(eleven_labs_tool):
with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY):
with patch(
"langchain.utils.get_from_dict_or_env", return_value=API_KEY
):
values = {"eleven_api_key": None}
validated_values = eleven_labs_tool.validate_environment(values)
validated_values = eleven_labs_tool.validate_environment(
values
)
assert "eleven_api_key" in validated_values
# Mocking the external library
def test_run_text_to_speech_with_mock(eleven_labs_tool):
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch(
"your_module._import_elevenlabs") as mock_elevenlabs:
with patch(
"tempfile.NamedTemporaryFile", mock_open()
) as mock_file, patch(
"your_module._import_elevenlabs"
) as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.return_value = (b"fake_audio_data")
mock_elevenlabs_instance.generate.return_value = (
b"fake_audio_data"
)
eleven_labs_tool.run(SAMPLE_TEXT)
assert mock_file.call_args[1]["suffix"] == ".wav"
assert mock_file.call_args[1]["delete"] is False
@ -65,11 +80,14 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool):
with patch("your_module._import_elevenlabs") as mock_elevenlabs:
mock_elevenlabs_instance = mock_elevenlabs.return_value
mock_elevenlabs_instance.generate.side_effect = Exception(
"Test Exception")
"Test Exception"
)
with pytest.raises(
RuntimeError,
match=("Error while running ElevenLabsText2SpeechTool: Test"
" Exception"),
RuntimeError,
match=(
"Error while running ElevenLabsText2SpeechTool: Test"
" Exception"
),
):
eleven_labs_tool.run(SAMPLE_TEXT)
@ -79,7 +97,9 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool):
"model",
[ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL],
)
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
def test_run_text_to_speech_with_different_models(
eleven_labs_tool, model
):
eleven_labs_tool.model = model
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
assert isinstance(speech_file, str)

@ -39,4 +39,6 @@ def test_fire_function_caller_run(mocker):
tokenizer.batch_decode.assert_called_once_with(generated_ids)
# Assert the decoded output is printed
assert decoded_output in mocker.patch.object(print, "call_args_list")
assert decoded_output in mocker.patch.object(
print, "call_args_list"
)

@ -38,7 +38,9 @@ def fuyu_instance():
# Test using the fixture.
def test_fuyu_processor_initialization(fuyu_instance):
assert isinstance(fuyu_instance.processor, FuyuProcessor)
assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor)
assert isinstance(
fuyu_instance.image_processor, FuyuImageProcessor
)
# Test exception when providing an invalid image path.
@ -49,7 +51,6 @@ def test_invalid_image_path(fuyu_instance):
# Using monkeypatch to replace the Image.open method to simulate a failure.
def test_image_open_failure(fuyu_instance, monkeypatch):
def mock_open(*args, **kwargs):
raise Exception("Mocked failure")
@ -78,9 +79,13 @@ def test_tokenizer_type(fuyu_instance):
def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
assert (fuyu_instance.processor.image_processor ==
fuyu_instance.image_processor)
assert (fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer)
assert (
fuyu_instance.processor.image_processor
== fuyu_instance.image_processor
)
assert (
fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
)
def test_model_device_map(fuyu_instance):
@ -139,14 +144,22 @@ def test_get_img_invalid_path(fuyu_instance):
# Test `run` method with valid inputs
def test_run_valid_inputs(fuyu_instance):
with patch.object(fuyu_instance, "get_img") as mock_get_img, patch.object(
fuyu_instance, "processor") as mock_processor, patch.object(
fuyu_instance, "model") as mock_model:
with patch.object(
fuyu_instance, "get_img"
) as mock_get_img, patch.object(
fuyu_instance, "processor"
) as mock_processor, patch.object(
fuyu_instance, "model"
) as mock_model:
mock_get_img.return_value = "Test image"
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test text"]
result = fuyu_instance.run("Hello, world!", "valid/path/to/image.png")
result = fuyu_instance.run(
"Hello, world!", "valid/path/to/image.png"
)
assert result == ["Test text"]
@ -173,7 +186,9 @@ def test_run_invalid_image_path(fuyu_instance):
with patch.object(fuyu_instance, "get_img") as mock_get_img:
mock_get_img.side_effect = FileNotFoundError
with pytest.raises(FileNotFoundError):
fuyu_instance.run("Hello, world!", "invalid/path/to/image.png")
fuyu_instance.run(
"Hello, world!", "invalid/path/to/image.png"
)
# Test `__init__` method with default parameters

@ -24,8 +24,12 @@ def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
assert model.model is mock_genai_model
def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model):
model = Gemini(model_name="custom-model", gemini_api_key="custom-api-key")
def test_gemini_init_custom_params(
mock_gemini_api_key, mock_genai_model
):
model = Gemini(
model_name="custom-model", gemini_api_key="custom-api-key"
)
assert model.model_name == "custom-model"
assert model.gemini_api_key == "custom-api-key"
assert model.model is mock_genai_model
@ -50,13 +54,16 @@ def test_gemini_run_with_img(
response = model.run(task=task, img=img)
assert response == "Generated response"
mock_generate_content.assert_called_with(content=[task, "Processed image"])
mock_generate_content.assert_called_with(
content=[task, "Processed image"]
)
mock_process_img.assert_called_with(img=img)
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key,
mock_genai_model):
def test_gemini_run_without_img(
mock_generate_content, mock_gemini_api_key, mock_genai_model
):
model = Gemini()
task = "A cat"
response_mock = Mock(text="Generated response")
@ -69,8 +76,9 @@ def test_gemini_run_without_img(mock_generate_content, mock_gemini_api_key,
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_exception(mock_generate_content, mock_gemini_api_key,
mock_genai_model):
def test_gemini_run_exception(
mock_generate_content, mock_gemini_api_key, mock_genai_model
):
model = Gemini()
task = "A cat"
mock_generate_content.side_effect = Exception("Test exception")
@ -88,23 +96,30 @@ def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
with patch("builtins.open", create=True) as open_mock:
open_mock.return_value.__enter__.return_value.read.return_value = (
img_data)
img_data
)
processed_img = model.process_img(img)
assert processed_img == [{"mime_type": "image/png", "data": img_data}]
assert processed_img == [
{"mime_type": "image/png", "data": img_data}
]
open_mock.assert_called_with(img, "rb")
# Test Gemini initialization with missing API key
def test_gemini_init_missing_api_key():
with pytest.raises(ValueError, match="Please provide a Gemini API key"):
with pytest.raises(
ValueError, match="Please provide a Gemini API key"
):
Gemini(gemini_api_key=None)
# Test Gemini initialization with missing model name
def test_gemini_init_missing_model_name():
with pytest.raises(ValueError, match="Please provide a model name"):
with pytest.raises(
ValueError, match="Please provide a model name"
):
Gemini(model_name=None)
@ -126,20 +141,26 @@ def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
# Test Gemini process_img method with missing image
def test_gemini_process_img_missing_image(mock_gemini_api_key,
mock_genai_model):
def test_gemini_process_img_missing_image(
mock_gemini_api_key, mock_genai_model
):
model = Gemini()
img = None
with pytest.raises(ValueError, match="Please provide an image to process"):
with pytest.raises(
ValueError, match="Please provide an image to process"
):
model.process_img(img=img)
# Test Gemini process_img method with missing image type
def test_gemini_process_img_missing_image_type(mock_gemini_api_key,
mock_genai_model):
def test_gemini_process_img_missing_image_type(
mock_gemini_api_key, mock_genai_model
):
model = Gemini()
img = "cat.png"
with pytest.raises(ValueError, match="Please provide the image type"):
with pytest.raises(
ValueError, match="Please provide the image type"
):
model.process_img(img=img, type=None)
@ -147,7 +168,9 @@ def test_gemini_process_img_missing_image_type(mock_gemini_api_key,
def test_gemini_process_img_missing_api_key(mock_genai_model):
model = Gemini(gemini_api_key=None)
img = "cat.png"
with pytest.raises(ValueError, match="Please provide a Gemini API key"):
with pytest.raises(
ValueError, match="Please provide a Gemini API key"
):
model.process_img(img=img, type="image/png")
@ -170,7 +193,9 @@ def test_gemini_run_mock_img_processing(
response = model.run(task=task, img=img)
assert response == "Generated response"
mock_generate_content.assert_called_with(content=[task, "Processed image"])
mock_generate_content.assert_called_with(
content=[task, "Processed image"]
)
mock_process_img.assert_called_with(img=img)

@ -11,13 +11,16 @@ except ImportError:
@pytest.fixture
def api():
return Gigabind(host="localhost", port=8000, endpoint="embeddings")
return Gigabind(
host="localhost", port=8000, endpoint="embeddings"
)
@pytest.fixture
def mock(requests_mock):
requests_mock.post("http://localhost:8000/embeddings",
json={"result": "success"})
requests_mock.post(
"http://localhost:8000/embeddings", json={"result": "success"}
)
return requests_mock
@ -37,9 +40,9 @@ def test_run_with_audio(api, mock):
def test_run_with_all(api, mock):
response = api.run(text="Hello, world!",
vision="image.jpg",
audio="audio.mp3")
response = api.run(
text="Hello, world!", vision="image.jpg", audio="audio.mp3"
)
assert response == {"result": "success"}
@ -62,20 +65,9 @@ def test_retry_on_failure(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{
"status_code": 500,
"json": {}
},
{
"status_code": 500,
"json": {}
},
{
"status_code": 200,
"json": {
"result": "success"
}
},
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 200, "json": {"result": "success"}},
],
)
response = api.run(text="Hello, world!")
@ -86,18 +78,9 @@ def test_retry_exhausted(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{
"status_code": 500,
"json": {}
},
{
"status_code": 500,
"json": {}
},
{
"status_code": 500,
"json": {}
},
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
],
)
response = api.run(text="Hello, world!")
@ -110,7 +93,9 @@ def test_proxy_url(api):
def test_invalid_response(api, requests_mock):
requests_mock.post("http://localhost:8000/embeddings", text="not json")
requests_mock.post(
"http://localhost:8000/embeddings", text="not json"
)
response = api.run(text="Hello, world!")
assert response is None
@ -125,7 +110,9 @@ def test_connection_error(api, requests_mock):
def test_http_error(api, requests_mock):
requests_mock.post("http://localhost:8000/embeddings", status_code=500)
requests_mock.post(
"http://localhost:8000/embeddings", status_code=500
)
response = api.run(text="Hello, world!")
assert response is None
@ -161,7 +148,9 @@ def test_run_with_large_all(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions
large_vision = "image.jpg" * 10000 # 10,000 repetitions
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(text=large_text, vision=large_vision, audio=large_audio)
response = api.run(
text=large_text, vision=large_vision, audio=large_audio
)
assert response == {"result": "success"}

@ -26,9 +26,9 @@ def test_init(vision_api):
def test_encode_image(vision_api):
with patch(
"builtins.open",
mock_open(read_data=b"test_image_data"),
create=True,
"builtins.open",
mock_open(read_data=b"test_image_data"),
create=True,
):
encoded_image = vision_api.encode_image(img)
assert encoded_image == "dGVzdF9pbWFnZV9kYXRh"
@ -37,8 +37,8 @@ def test_encode_image(vision_api):
def test_run_success(vision_api):
expected_response = {"This is the model's response."}
with patch(
"requests.post",
return_value=Mock(json=lambda: expected_response),
"requests.post",
return_value=Mock(json=lambda: expected_response),
) as mock_post:
result = vision_api.run("What is this?", img)
mock_post.assert_called_once()
@ -46,7 +46,9 @@ def test_run_success(vision_api):
def test_run_request_error(vision_api):
with patch("requests.post", side_effect=RequestException("Request Error")):
with patch(
"requests.post", side_effect=RequestException("Request Error")
):
with pytest.raises(RequestException):
vision_api.run("What is this?", img)
@ -54,18 +56,20 @@ def test_run_request_error(vision_api):
def test_run_response_error(vision_api):
expected_response = {"error": "Model Error"}
with patch(
"requests.post",
return_value=Mock(json=lambda: expected_response),
"requests.post",
return_value=Mock(json=lambda: expected_response),
):
with pytest.raises(RuntimeError):
vision_api.run("What is this?", img)
def test_call(vision_api):
expected_response = {"choices": [{"text": "This is the model's response."}]}
expected_response = {
"choices": [{"text": "This is the model's response."}]
}
with patch(
"requests.post",
return_value=Mock(json=lambda: expected_response),
"requests.post",
return_value=Mock(json=lambda: expected_response),
) as mock_post:
result = vision_api("What is this?", img)
mock_post.assert_called_once()
@ -91,7 +95,9 @@ def test_initialization_with_custom_key():
def test_run_with_exception(gpt_api):
task = "What is in the image?"
img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")):
with patch(
"requests.post", side_effect=Exception("Test Exception")
):
with pytest.raises(Exception):
gpt_api.run(task, img_url)
@ -99,10 +105,14 @@ def test_run_with_exception(gpt_api):
def test_call_method_successful_response(gpt_api):
task = "What is in the image?"
img_url = img
response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]}
response_json = {
"choices": [{"text": "Answer from GPT-4 Vision"}]
}
mock_response = Mock()
mock_response.json.return_value = response_json
with patch("requests.post", return_value=mock_response) as mock_post:
with patch(
"requests.post", return_value=mock_response
) as mock_post:
result = gpt_api(task, img_url)
mock_post.assert_called_once()
assert result == response_json
@ -111,7 +121,9 @@ def test_call_method_successful_response(gpt_api):
def test_call_method_with_exception(gpt_api):
task = "What is in the image?"
img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")):
with patch(
"requests.post", side_effect=Exception("Test Exception")
):
with pytest.raises(Exception):
gpt_api(task, img_url)
@ -119,17 +131,16 @@ def test_call_method_with_exception(gpt_api):
@pytest.mark.asyncio
async def test_arun_success(vision_api):
expected_response = {
"choices": [{
"message": {
"content": "This is the model's response."
}
}]
"choices": [
{"message": {"content": "This is the model's response."}}
]
}
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(
return_value=expected_response)),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(
json=AsyncMock(return_value=expected_response)
),
) as mock_post:
result = await vision_api.arun("What is this?", img)
mock_post.assert_called_once()
@ -139,9 +150,9 @@ async def test_arun_success(vision_api):
@pytest.mark.asyncio
async def test_arun_request_error(vision_api):
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=Exception("Request Error"),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=Exception("Request Error"),
):
with pytest.raises(Exception):
await vision_api.arun("What is this?", img)
@ -149,15 +160,13 @@ async def test_arun_request_error(vision_api):
def test_run_many_success(vision_api):
expected_response = {
"choices": [{
"message": {
"content": "This is the model's response."
}
}]
"choices": [
{"message": {"content": "This is the model's response."}}
]
}
with patch(
"requests.post",
return_value=Mock(json=lambda: expected_response),
"requests.post",
return_value=Mock(json=lambda: expected_response),
) as mock_post:
tasks = ["What is this?", "What is that?"]
imgs = [img, img]
@ -170,7 +179,9 @@ def test_run_many_success(vision_api):
def test_run_many_request_error(vision_api):
with patch("requests.post", side_effect=RequestException("Request Error")):
with patch(
"requests.post", side_effect=RequestException("Request Error")
):
tasks = ["What is this?", "What is that?"]
imgs = [img, img]
with pytest.raises(RequestException):
@ -180,9 +191,11 @@ def test_run_many_request_error(vision_api):
@pytest.mark.asyncio
async def test_arun_json_decode_error(vision_api):
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(
json=AsyncMock(side_effect=ValueError)
),
):
with pytest.raises(ValueError):
await vision_api.arun("What is this?", img)
@ -192,9 +205,11 @@ async def test_arun_json_decode_error(vision_api):
async def test_arun_api_error(vision_api):
error_response = {"error": {"message": "API Error"}}
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(return_value=error_response)),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(
json=AsyncMock(return_value=error_response)
),
):
with pytest.raises(Exception, match="API Error"):
await vision_api.arun("What is this?", img)
@ -204,10 +219,11 @@ async def test_arun_api_error(vision_api):
async def test_arun_unexpected_response(vision_api):
unexpected_response = {"unexpected": "response"}
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(json=AsyncMock(
return_value=unexpected_response)),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
return_value=AsyncMock(
json=AsyncMock(return_value=unexpected_response)
),
):
with pytest.raises(Exception, match="Unexpected response"):
await vision_api.arun("What is this?", img)
@ -216,9 +232,9 @@ async def test_arun_unexpected_response(vision_api):
@pytest.mark.asyncio
async def test_arun_retries(vision_api):
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=ClientResponseError(None, None),
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=ClientResponseError(None, None),
) as mock_post:
with pytest.raises(ClientResponseError):
await vision_api.arun("What is this?", img)
@ -228,9 +244,9 @@ async def test_arun_retries(vision_api):
@pytest.mark.asyncio
async def test_arun_timeout(vision_api):
with patch(
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
):
with pytest.raises(asyncio.TimeoutError):
await vision_api.arun("What is this?", img)

@ -133,7 +133,9 @@ def test_llm_set_repitition_penalty(llm_instance):
def test_llm_set_no_repeat_ngram_size(llm_instance):
new_no_repeat_ngram_size = 6
llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size)
assert (llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size)
assert (
llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size
)
# Test for setting temperature
@ -183,7 +185,9 @@ def test_llm_set_model_id(llm_instance):
# Test for setting model
@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained")
@patch(
"swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
)
def test_llm_set_model(mock_model, llm_instance):
mock_model.return_value = "mocked model"
llm_instance.set_model(mock_model)

@ -14,14 +14,19 @@ def mock_pipeline():
@pytest.fixture
def pipeline(mock_pipeline):
return HuggingfacePipeline("text-generation",
"meta-llama/Llama-2-13b-chat-hf")
return HuggingfacePipeline(
"text-generation", "meta-llama/Llama-2-13b-chat-hf"
)
def test_init(pipeline, mock_pipeline):
assert pipeline.task_type == "text-generation"
assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
assert (pipeline.use_fp8 is True if torch.cuda.is_available() else False)
assert (
pipeline.use_fp8 is True
if torch.cuda.is_available()
else False
)
mock_pipeline.assert_called_once_with(
"text-generation",
"meta-llama/Llama-2-13b-chat-hf",
@ -46,5 +51,6 @@ def test_run_with_different_task(pipeline, mock_pipeline):
mock_pipeline.return_value = "Generated text"
result = pipeline.run("text-classification", "Hello, world!")
assert result == "Generated text"
mock_pipeline.assert_called_once_with("text-classification",
"Hello, world!")
mock_pipeline.assert_called_once_with(
"text-classification", "Hello, world!"
)

@ -18,7 +18,10 @@ def llm_instance():
# Test for instantiation and attributes
def test_llm_initialization(llm_instance):
assert (llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha")
assert (
llm_instance.model_id
== "NousResearch/Nous-Hermes-2-Vision-Alpha"
)
assert llm_instance.max_length == 500
# ... add more assertions for all default attributes
@ -85,12 +88,15 @@ def test_llm_memory_consumption(llm_instance):
)
def test_llm_initialization_params(model_id, max_length):
if max_length:
instance = HuggingfaceLLM(model_id=model_id, max_length=max_length)
instance = HuggingfaceLLM(
model_id=model_id, max_length=max_length
)
assert instance.max_length == max_length
else:
instance = HuggingfaceLLM(model_id=model_id)
assert (instance.max_length == 500
) # Assuming 500 is the default max_length
assert (
instance.max_length == 500
) # Assuming 500 is the default max_length
# Test for setting an invalid device
@ -138,7 +144,9 @@ def test_llm_run_output_length(mock_run, llm_instance):
# Test the tokenizer handling special tokens correctly
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance):
def test_llm_tokenizer_special_tokens(
mock_decode, mock_encode, llm_instance
):
mock_encode.return_value = "encoded input with special tokens"
mock_decode.return_value = "decoded output with special tokens"
result = llm_instance.run("test task with special tokens")
@ -164,8 +172,9 @@ def test_llm_response_time(mock_run, llm_instance):
start_time = time.time()
llm_instance.run("test task for response time")
end_time = time.time()
assert (end_time - start_time
< 1) # Assuming the response should be faster than 1 second
assert (
end_time - start_time < 1
) # Assuming the response should be faster than 1 second
# Test the logging of a warning for long inputs
@ -188,10 +197,13 @@ def test_llm_run_model_exception(mock_generate, llm_instance):
# Test the behavior when GPU is forced but not available
@patch("torch.cuda.is_available", return_value=False)
def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance):
def test_llm_force_gpu_when_unavailable(
mock_is_available, llm_instance
):
with pytest.raises(EnvironmentError):
llm_instance.set_device(
"cuda") # Attempt to set CUDA when it's not available
"cuda"
) # Attempt to set CUDA when it's not available
# Test for proper cleanup after model use (releasing resources)
@ -209,8 +221,9 @@ def test_llm_multilingual_input(mock_run, llm_instance):
mock_run.return_value = "mocked multilingual output"
multilingual_input = "Bonjour, ceci est un test multilingue."
result = llm_instance.run(multilingual_input)
assert isinstance(result,
str) # Simple check to ensure output is string type
assert isinstance(
result, str
) # Simple check to ensure output is string type
# Test caching mechanism to prevent re-running the same inputs

@ -13,8 +13,8 @@ from swarms.models.idefics import (
@pytest.fixture
def idefics_instance():
with patch(
"torch.cuda.is_available",
return_value=False): # Assuming tests are run on CPU for simplicity
"torch.cuda.is_available", return_value=False
): # Assuming tests are run on CPU for simplicity
instance = Idefics()
return instance
@ -36,8 +36,8 @@ def test_init_default(idefics_instance):
)
def test_init_device(device, expected):
with patch(
"torch.cuda.is_available",
return_value=True if expected == "cuda" else False,
"torch.cuda.is_available",
return_value=True if expected == "cuda" else False,
):
instance = Idefics(device=device)
assert instance.device == expected
@ -46,10 +46,14 @@ def test_init_device(device, expected):
# Test `run` method
def test_run(idefics_instance):
prompts = [["User: Test"]]
with patch.object(idefics_instance,
"processor") as mock_processor, patch.object(
idefics_instance, "model") as mock_model:
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
with patch.object(
idefics_instance, "processor"
) as mock_processor, patch.object(
idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"]
@ -61,10 +65,14 @@ def test_run(idefics_instance):
# Test `__call__` method (using the same logic as run for simplicity)
def test_call(idefics_instance):
prompts = [["User: Test"]]
with patch.object(idefics_instance,
"processor") as mock_processor, patch.object(
idefics_instance, "model") as mock_model:
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
with patch.object(
idefics_instance, "processor"
) as mock_processor, patch.object(
idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"]
@ -77,7 +85,9 @@ def test_call(idefics_instance):
def test_chat(idefics_instance):
user_input = "User: Hello"
response = "Model: Hi there!"
with patch.object(idefics_instance, "run", return_value=[response]):
with patch.object(
idefics_instance, "run", return_value=[response]
):
result = idefics_instance.chat(user_input)
assert result == response
@ -87,13 +97,16 @@ def test_chat(idefics_instance):
# Test `set_checkpoint` method
def test_set_checkpoint(idefics_instance):
new_checkpoint = "new_checkpoint"
with patch.object(IdeficsForVisionText2Text,
"from_pretrained") as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"):
with patch.object(
IdeficsForVisionText2Text, "from_pretrained"
) as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"
):
idefics_instance.set_checkpoint(new_checkpoint)
mock_from_pretrained.assert_called_with(new_checkpoint,
torch_dtype=torch.bfloat16)
mock_from_pretrained.assert_called_with(
new_checkpoint, torch_dtype=torch.bfloat16
)
# Test `set_device` method
@ -122,7 +135,7 @@ def test_clear_chat_history(idefics_instance):
# Exception Tests
def test_run_with_empty_prompts(idefics_instance):
with pytest.raises(
Exception
Exception
): # Replace Exception with the actual exception that may arise for an empty prompt.
idefics_instance.run([])
@ -130,10 +143,14 @@ def test_run_with_empty_prompts(idefics_instance):
# Test `run` method with batched_mode set to False
def test_run_batched_mode_false(idefics_instance):
task = "User: Test"
with patch.object(idefics_instance,
"processor") as mock_processor, patch.object(
idefics_instance, "model") as mock_model:
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
with patch.object(
idefics_instance, "processor"
) as mock_processor, patch.object(
idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {
"input_ids": torch.tensor([1, 2, 3])
}
mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"]
@ -146,7 +163,9 @@ def test_run_batched_mode_false(idefics_instance):
# Test `run` method with an exception
def test_run_with_exception(idefics_instance):
task = "User: Test"
with patch.object(idefics_instance, "processor") as mock_processor:
with patch.object(
idefics_instance, "processor"
) as mock_processor:
mock_processor.side_effect = Exception("Test exception")
with pytest.raises(Exception):
idefics_instance.run(task)
@ -155,21 +174,24 @@ def test_run_with_exception(idefics_instance):
# Test `set_model_name` method
def test_set_model_name(idefics_instance):
new_model_name = "new_model_name"
with patch.object(IdeficsForVisionText2Text,
"from_pretrained") as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"):
with patch.object(
IdeficsForVisionText2Text, "from_pretrained"
) as mock_from_pretrained, patch.object(
AutoProcessor, "from_pretrained"
):
idefics_instance.set_model_name(new_model_name)
assert idefics_instance.model_name == new_model_name
mock_from_pretrained.assert_called_with(new_model_name,
torch_dtype=torch.bfloat16)
mock_from_pretrained.assert_called_with(
new_model_name, torch_dtype=torch.bfloat16
)
# Test `__init__` method with device set to None
def test_init_device_none():
with patch(
"torch.cuda.is_available",
return_value=False,
"torch.cuda.is_available",
return_value=False,
):
instance = Idefics(device=None)
assert instance.device == "cpu"
@ -178,8 +200,8 @@ def test_init_device_none():
# Test `__init__` method with device set to "cuda"
def test_init_device_cuda():
with patch(
"torch.cuda.is_available",
return_value=True,
"torch.cuda.is_available",
return_value=True,
):
instance = Idefics(device="cuda")
assert instance.device == "cuda"

@ -16,7 +16,9 @@ def mock_image_request():
img_data = open(TEST_IMAGE_URL, "rb").read()
mock_resp = Mock()
mock_resp.raw = img_data
with patch.object(requests, "get", return_value=mock_resp) as _fixture:
with patch.object(
requests, "get", return_value=mock_resp
) as _fixture:
yield _fixture
@ -45,16 +47,18 @@ def test_get_image(mock_image_request):
# Test multimodal grounding
def test_multimodal_grounding(mock_image_request):
kosmos = Kosmos()
kosmos.multimodal_grounding("Find the red apple in the image.",
TEST_IMAGE_URL)
kosmos.multimodal_grounding(
"Find the red apple in the image.", TEST_IMAGE_URL
)
# TODO: Validate the result if possible
# Test referring expression comprehension
def test_referring_expression_comprehension(mock_image_request):
kosmos = Kosmos()
kosmos.referring_expression_comprehension("Show me the green bottle.",
TEST_IMAGE_URL)
kosmos.referring_expression_comprehension(
"Show me the green bottle.", TEST_IMAGE_URL
)
# TODO: Validate the result if possible
@ -89,7 +93,6 @@ IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=fo
# Mock response for requests.get()
class MockResponse:
@staticmethod
def json():
return {}
@ -108,23 +111,30 @@ def kosmos():
# Mocking the requests.get() method
@pytest.fixture
def mock_request_get(monkeypatch):
monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse())
monkeypatch.setattr(
requests, "get", lambda url, **kwargs: MockResponse()
)
@pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding(kosmos):
kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1)
kosmos.multimodal_grounding(
"Find the red apple in the image.", IMG_URL1
)
@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension(kosmos):
kosmos.referring_expression_comprehension("Show me the green bottle.",
IMG_URL2)
kosmos.referring_expression_comprehension(
"Show me the green bottle.", IMG_URL2
)
@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_generation(kosmos):
kosmos.referring_expression_generation("It is on the table.", IMG_URL3)
kosmos.referring_expression_generation(
"It is on the table.", IMG_URL3
)
@pytest.mark.usefixtures("mock_request_get")
@ -144,13 +154,16 @@ def test_grounded_image_captioning_detailed(kosmos):
@pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding_2(kosmos):
kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2)
kosmos.multimodal_grounding(
"Find the yellow fruit in the image.", IMG_URL2
)
@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension_2(kosmos):
kosmos.referring_expression_comprehension("Where is the water bottle?",
IMG_URL3)
kosmos.referring_expression_comprehension(
"Where is the water bottle?", IMG_URL3
)
@pytest.mark.usefixtures("mock_request_get")

@ -18,7 +18,6 @@ def test_llama_model_loading(llama_caller):
# Test adding and calling custom functions
def test_llama_custom_function(llama_caller):
def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}"
@ -40,11 +39,13 @@ def test_llama_custom_function(llama_caller):
],
)
result = llama_caller.call_function("sample_function",
arg1="arg1_value",
arg2="arg2_value")
result = llama_caller.call_function(
"sample_function", arg1="arg1_value", arg2="arg2_value"
)
assert (
result == "Sample function called with args: arg1_value, arg2_value")
result
== "Sample function called with args: arg1_value, arg2_value"
)
# Test streaming user prompts
@ -63,7 +64,6 @@ def test_llama_custom_function_not_found(llama_caller):
# Test invalid arguments for custom function
def test_llama_custom_function_invalid_arguments(llama_caller):
def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}"
@ -86,7 +86,9 @@ def test_llama_custom_function_invalid_arguments(llama_caller):
)
with pytest.raises(TypeError):
llama_caller.call_function("sample_function", arg1="arg1_value")
llama_caller.call_function(
"sample_function", arg1="arg1_value"
)
# Test streaming with custom runtime

@ -21,18 +21,22 @@ def test_mixtral_run(mock_model, mock_tokenizer):
mixtral = Mixtral()
mock_tokenizer_instance = MagicMock()
mock_model_instance = MagicMock()
mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance)
mock_tokenizer.from_pretrained.return_value = (
mock_tokenizer_instance
)
mock_model.from_pretrained.return_value = mock_model_instance
mock_tokenizer_instance.return_tensors = "pt"
mock_model_instance.generate.return_value = [101, 102, 103]
mock_tokenizer_instance.decode.return_value = "Generated text"
result = mixtral.run("Test task")
assert result == "Generated text"
mock_tokenizer_instance.assert_called_once_with("Test task",
return_tensors="pt")
mock_tokenizer_instance.assert_called_once_with(
"Test task", return_tensors="pt"
)
mock_model_instance.generate.assert_called_once()
mock_tokenizer_instance.decode.assert_called_once_with(
[101, 102, 103], skip_special_tokens=True)
[101, 102, 103], skip_special_tokens=True
)
@patch("swarms.models.mixtral.AutoTokenizer")
@ -41,7 +45,9 @@ def test_mixtral_run_error(mock_model, mock_tokenizer):
mixtral = Mixtral()
mock_tokenizer_instance = MagicMock()
mock_model_instance = MagicMock()
mock_tokenizer.from_pretrained.return_value = (mock_tokenizer_instance)
mock_tokenizer.from_pretrained.return_value = (
mock_tokenizer_instance
)
mock_model.from_pretrained.return_value = mock_model_instance
mock_tokenizer_instance.return_tensors = "pt"
mock_model_instance.generate.side_effect = Exception("Test error")

@ -25,10 +25,14 @@ def test_mpt7b_run():
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
output = mpt.run("generate", "Once upon a time in a land far, far away...")
output = mpt.run(
"generate", "Once upon a time in a land far, far away..."
)
assert isinstance(output, str)
assert output.startswith("Once upon a time in a land far, far away...")
assert output.startswith(
"Once upon a time in a land far, far away..."
)
def test_mpt7b_run_invalid_task():
@ -51,10 +55,14 @@ def test_mpt7b_generate():
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
output = mpt.generate("Once upon a time in a land far, far away...")
output = mpt.generate(
"Once upon a time in a land far, far away..."
)
assert isinstance(output, str)
assert output.startswith("Once upon a time in a land far, far away...")
assert output.startswith(
"Once upon a time in a land far, far away..."
)
def test_mpt7b_batch_generate():

@ -43,7 +43,9 @@ def test_model_initialization(setup_nougat):
"cuda_available, expected_device",
[(True, "cuda"), (False, "cpu")],
)
def test_device_initialization(cuda_available, expected_device, monkeypatch):
def test_device_initialization(
cuda_available, expected_device, monkeypatch
):
monkeypatch.setattr(
torch,
"cuda",
@ -72,7 +74,9 @@ def test_get_image_invalid_path(setup_nougat):
(10, 50),
],
)
def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens):
def test_model_call_with_diff_params(
setup_nougat, min_len, max_tokens
):
setup_nougat.min_length = min_len
setup_nougat.max_new_tokens = max_tokens
@ -103,11 +107,11 @@ def test_model_call_mocked_output(setup_nougat):
def mock_processor_and_model():
"""Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior."""
with patch(
"transformers.NougatProcessor.from_pretrained",
return_value=Mock(),
"transformers.NougatProcessor.from_pretrained",
return_value=Mock(),
), patch(
"transformers.VisionEncoderDecoderModel.from_pretrained",
return_value=Mock(),
"transformers.VisionEncoderDecoderModel.from_pretrained",
return_value=Mock(),
):
yield
@ -118,7 +122,8 @@ def test_nougat_with_sample_image_1(setup_nougat):
os.path.join(
"sample_images",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@ -135,7 +140,8 @@ def test_nougat_min_length_param(setup_nougat):
os.path.join(
"sample_images",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@ -146,7 +152,8 @@ def test_nougat_max_new_tokens_param(setup_nougat):
os.path.join(
"sample_images",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@ -157,13 +164,16 @@ def test_nougat_different_model_path(setup_nougat):
os.path.join(
"sample_images",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@pytest.mark.usefixtures("mock_processor_and_model")
def test_nougat_bad_image_path(setup_nougat):
with pytest.raises(Exception): # Adjust the exception type accordingly.
with pytest.raises(
Exception
): # Adjust the exception type accordingly.
setup_nougat("bad_image_path.png")
@ -173,7 +183,8 @@ def test_nougat_image_large_size(setup_nougat):
os.path.join(
"sample_images",
"https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@ -183,7 +194,8 @@ def test_nougat_image_small_size(setup_nougat):
os.path.join(
"sample_images",
"https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
))
)
)
assert isinstance(result, str)
@ -193,7 +205,8 @@ def test_nougat_image_varied_content(setup_nougat):
os.path.join(
"sample_images",
"https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
))
)
)
assert isinstance(result, str)
@ -203,5 +216,6 @@ def test_nougat_image_with_metadata(setup_nougat):
os.path.join(
"sample_images",
"https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
))
)
)
assert isinstance(result, str)

@ -4,16 +4,18 @@ from swarms.models.qwen import QwenVLMultiModal
def test_post_init():
with patch("swarms.models.qwen.AutoTokenizer.from_pretrained"
) as mock_tokenizer, patch(
"swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
) as mock_model:
with patch(
"swarms.models.qwen.AutoTokenizer.from_pretrained"
) as mock_tokenizer, patch(
"swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
) as mock_model:
mock_tokenizer.return_value = Mock()
mock_model.return_value = Mock()
model = QwenVLMultiModal()
mock_tokenizer.assert_called_once_with(model.model_name,
trust_remote_code=True)
mock_tokenizer.assert_called_once_with(
model.model_name, trust_remote_code=True
)
mock_model.assert_called_once_with(
model.model_name,
device_map=model.device,
@ -23,31 +25,37 @@ def test_post_init():
def test_run():
with patch(
"swarms.models.qwen.AutoTokenizer.from_list_format"
"swarms.models.qwen.AutoTokenizer.from_list_format"
) as mock_format, patch(
"swarms.models.qwen.AutoTokenizer.__call__") as mock_call, patch(
"swarms.models.qwen.AutoModelForCausalLM.generate"
) as mock_generate, patch(
"swarms.models.qwen.AutoTokenizer.decode") as mock_decode:
"swarms.models.qwen.AutoTokenizer.__call__"
) as mock_call, patch(
"swarms.models.qwen.AutoModelForCausalLM.generate"
) as mock_generate, patch(
"swarms.models.qwen.AutoTokenizer.decode"
) as mock_decode:
mock_format.return_value = Mock()
mock_call.return_value = Mock()
mock_generate.return_value = Mock()
mock_decode.return_value = "response"
model = QwenVLMultiModal()
response = model.run("Hello, how are you?",
"https://example.com/image.jpg")
response = model.run(
"Hello, how are you?", "https://example.com/image.jpg"
)
assert response == "response"
def test_chat():
with patch("swarms.models.qwen.AutoModelForCausalLM.chat") as mock_chat:
with patch(
"swarms.models.qwen.AutoModelForCausalLM.chat"
) as mock_chat:
mock_chat.return_value = ("response", ["history"])
model = QwenVLMultiModal()
response, history = model.chat("Hello, how are you?",
"https://example.com/image.jpg")
response, history = model.chat(
"Hello, how are you?", "https://example.com/image.jpg"
)
assert response == "response"
assert history == ["history"]

@ -16,11 +16,16 @@ def speecht5_model():
def test_speecht5_init(speecht5_model):
assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__)
assert isinstance(
speecht5_model.processor, SpeechT5.processor.__class__
)
assert isinstance(speecht5_model.model, SpeechT5.model.__class__)
assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__)
assert isinstance(speecht5_model.embeddings_dataset,
torch.utils.data.Dataset)
assert isinstance(
speecht5_model.vocoder, SpeechT5.vocoder.__class__
)
assert isinstance(
speecht5_model.embeddings_dataset, torch.utils.data.Dataset
)
def test_speecht5_call(speecht5_model):
@ -44,7 +49,10 @@ def test_speecht5_set_model(speecht5_model):
speecht5_model.set_model(new_model_name)
assert speecht5_model.model_name == new_model_name
assert speecht5_model.processor.model_name == new_model_name
assert (speecht5_model.model.config.model_name_or_path == new_model_name)
assert (
speecht5_model.model.config.model_name_or_path
== new_model_name
)
speecht5_model.set_model(old_model_name) # Restore original model
@ -54,8 +62,12 @@ def test_speecht5_set_vocoder(speecht5_model):
speecht5_model.set_vocoder(new_vocoder_name)
assert speecht5_model.vocoder_name == new_vocoder_name
assert (
speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name)
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder
speecht5_model.vocoder.config.model_name_or_path
== new_vocoder_name
)
speecht5_model.set_vocoder(
old_vocoder_name
) # Restore original vocoder
def test_speecht5_set_embeddings_dataset(speecht5_model):
@ -63,10 +75,12 @@ def test_speecht5_set_embeddings_dataset(speecht5_model):
new_dataset_name = "Matthijs/cmu-arctic-xvectors-test"
speecht5_model.set_embeddings_dataset(new_dataset_name)
assert speecht5_model.dataset_name == new_dataset_name
assert isinstance(speecht5_model.embeddings_dataset,
torch.utils.data.Dataset)
assert isinstance(
speecht5_model.embeddings_dataset, torch.utils.data.Dataset
)
speecht5_model.set_embeddings_dataset(
old_dataset_name) # Restore original dataset
old_dataset_name
) # Restore original dataset
def test_speecht5_get_sampling_rate(speecht5_model):
@ -98,7 +112,9 @@ def test_speecht5_change_dataset_split(speecht5_model):
def test_speecht5_load_custom_embedding(speecht5_model):
xvector = [0.1, 0.2, 0.3, 0.4, 0.5]
embedding = speecht5_model.load_custom_embedding(xvector)
assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)))
assert torch.all(
torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))
)
def test_speecht5_with_different_speakers(speecht5_model):
@ -109,7 +125,9 @@ def test_speecht5_with_different_speakers(speecht5_model):
assert isinstance(speech, torch.Tensor)
def test_speecht5_save_speech_with_different_extensions(speecht5_model,):
def test_speecht5_save_speech_with_different_extensions(
speecht5_model,
):
text = "Hello, how are you?"
speech = speecht5_model(text)
extensions = [".wav", ".flac"]
@ -144,4 +162,6 @@ def test_speecht5_change_vocoder_model(speecht5_model):
speecht5_model.set_vocoder(new_vocoder_name)
speech = speecht5_model(text)
assert isinstance(speech, torch.Tensor)
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder
speecht5_model.set_vocoder(
old_vocoder_name
) # Restore original vocoder

@ -21,30 +21,36 @@ def test_ssd1b_call(ssd1b_model):
image_url = ssd1b_model(task, neg_prompt)
assert isinstance(image_url, str)
assert image_url.startswith(
"https://") # Assuming it starts with "https://"
"https://"
) # Assuming it starts with "https://"
# Add more tests for various aspects of the class and methods
# Example of a parameterized test for different tasks
@pytest.mark.parametrize("task",
["A painting of a cat", "A painting of a tree"])
@pytest.mark.parametrize(
"task", ["A painting of a cat", "A painting of a tree"]
)
def test_ssd1b_parameterized_task(ssd1b_model, task):
image_url = ssd1b_model(task)
assert isinstance(image_url, str)
assert image_url.startswith(
"https://") # Assuming it starts with "https://"
"https://"
) # Assuming it starts with "https://"
# Example of a test using mocks to isolate units of code
def test_ssd1b_with_mock(ssd1b_model, mocker):
mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline
mocker.patch(
"your_module.StableDiffusionXLPipeline"
) # Mock the pipeline
task = "A painting of a cat"
image_url = ssd1b_model(task)
assert isinstance(image_url, str)
assert image_url.startswith(
"https://") # Assuming it starts with "https://"
"https://"
) # Assuming it starts with "https://"
def test_ssd1b_call_with_cache(ssd1b_model):
@ -62,8 +68,9 @@ def test_ssd1b_invalid_task(ssd1b_model):
def test_ssd1b_failed_api_call(ssd1b_model, mocker):
mocker.patch("your_module.StableDiffusionXLPipeline"
) # Mock the pipeline to raise an exception
mocker.patch(
"your_module.StableDiffusionXLPipeline"
) # Mock the pipeline to raise an exception
task = "A painting of a cat"
with pytest.raises(Exception):
ssd1b_model(task)

@ -19,16 +19,18 @@ def test_timm_model_init():
def test_timm_model_call():
with patch("swarms.models.timm.create_model") as mock_create_model:
with patch(
"swarms.models.timm.create_model"
) as mock_create_model:
model_name = "resnet18"
pretrained = True
in_chans = 3
timm_model = TimmModel(model_name, pretrained, in_chans)
task = torch.rand(1, in_chans, 224, 224)
result = timm_model(task)
mock_create_model.assert_called_once_with(model_name,
pretrained=pretrained,
in_chans=in_chans)
mock_create_model.assert_called_once_with(
model_name, pretrained=pretrained, in_chans=in_chans
)
assert result == mock_create_model.return_value(task)

@ -22,14 +22,17 @@ def test_create_model(sample_model_info):
def test_call(sample_model_info):
model_handler = TimmModel()
input_tensor = torch.randn(1, 3, 224, 224)
output_shape = model_handler.__call__(sample_model_info, input_tensor)
output_shape = model_handler.__call__(
sample_model_info, input_tensor
)
assert isinstance(output_shape, torch.Size)
def test_get_supported_models_mock():
model_handler = TimmModel()
model_handler._get_supported_models = Mock(
return_value=["resnet18", "resnet50"])
return_value=["resnet18", "resnet50"]
)
supported_models = model_handler._get_supported_models()
assert supported_models == ["resnet18", "resnet50"]

@ -55,11 +55,7 @@ def test_init_custom_params(mock_api_key):
def test_run_success(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{
"message": {
"content": "Generated response"
}
}]
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response
@ -73,7 +69,8 @@ def test_run_success(mock_post, mock_api_key):
@patch("swarms.models.together_model.requests.post")
def test_run_failure(mock_post, mock_api_key):
mock_post.side_effect = requests.exceptions.RequestException(
"Request failed")
"Request failed"
)
model = TogetherLLM()
task = "What is the color of the object?"
@ -92,7 +89,9 @@ def test_run_with_logging_enabled(caplog, mock_api_key):
assert "Sending request to" in caplog.text
@pytest.mark.parametrize("invalid_input", [None, 123, ["list", "of", "items"]])
@pytest.mark.parametrize(
"invalid_input", [None, 123, ["list", "of", "items"]]
)
def test_invalid_task_input(invalid_input, mock_api_key):
model = TogetherLLM()
response = model.run(invalid_input)
@ -104,11 +103,7 @@ def test_invalid_task_input(invalid_input, mock_api_key):
def test_run_streaming_enabled(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{
"message": {
"content": "Generated response"
}
}]
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response

@ -20,7 +20,9 @@ def test_ultralytics_call():
args = (1, 2, 3)
kwargs = {"a": "A", "b": "B"}
result = ultralytics(task, *args, **kwargs)
mock_yolo.return_value.assert_called_once_with(task, *args, **kwargs)
mock_yolo.return_value.assert_called_once_with(
task, *args, **kwargs
)
assert result == mock_yolo.return_value.return_value

@ -21,13 +21,17 @@ def test_vilt_initialization(vilt_instance):
# 2. Test Model Predictions
@patch.object(requests, "get")
@patch.object(Image, "open")
def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance):
def test_vilt_prediction(
mock_image_open, mock_requests_get, vilt_instance
):
mock_image = Mock()
mock_image_open.return_value = mock_image
mock_requests_get.return_value.raw = Mock()
# It's a mock response, so no real answer expected
with pytest.raises(Exception): # Ensure exception is more specific
with pytest.raises(
Exception
): # Ensure exception is more specific
vilt_instance(
"What is this image",
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
@ -62,7 +66,9 @@ def test_vilt_network_exception(vilt_instance):
],
)
def test_vilt_various_inputs(text, image_url, vilt_instance):
with pytest.raises(Exception): # Again, ensure exception is more specific
with pytest.raises(
Exception
): # Again, ensure exception is more specific
vilt_instance(text, image_url)

@ -32,7 +32,9 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length):
@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
def test_yi34b_generate_text_with_temperature(
yi34b_model, temperature
):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, temperature=temperature)
assert isinstance(generated_text, str)
@ -40,24 +42,27 @@ def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
prompt = None # Invalid prompt
with pytest.raises(ValueError,
match="Input prompt must be a non-empty string"):
with pytest.raises(
ValueError, match="Input prompt must be a non-empty string"
):
yi34b_model(prompt)
def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
prompt = "There's a place where time stands still."
max_length = -1 # Invalid max_length
with pytest.raises(ValueError,
match="max_length must be a positive integer"):
with pytest.raises(
ValueError, match="max_length must be a positive integer"
):
yi34b_model(prompt, max_length=max_length)
def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
prompt = "There's a place where time stands still."
temperature = 2.0 # Invalid temperature
with pytest.raises(ValueError,
match="temperature must be between 0.01 and 1.0"):
with pytest.raises(
ValueError, match="temperature must be between 0.01 and 1.0"
):
yi34b_model(prompt, temperature=temperature)
@ -78,32 +83,40 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
prompt = "There's a place where time stands still."
top_k = -1 # Invalid top_k
with pytest.raises(ValueError,
match="top_k must be a non-negative integer"):
with pytest.raises(
ValueError, match="top_k must be a non-negative integer"
):
yi34b_model(prompt, top_k=top_k)
def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
prompt = "There's a place where time stands still."
top_p = 1.5 # Invalid top_p
with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
with pytest.raises(
ValueError, match="top_p must be between 0.0 and 1.0"
):
yi34b_model(prompt, top_p=top_p)
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(yi34b_model,
repitition_penalty):
def test_yi34b_generate_text_with_repitition_penalty(
yi34b_model, repitition_penalty
):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty)
generated_text = yi34b_model(
prompt, repitition_penalty=repitition_penalty
)
assert isinstance(generated_text, str)
def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model,):
def test_yi34b_generate_text_with_invalid_repitition_penalty(
yi34b_model,
):
prompt = "There's a place where time stands still."
repitition_penalty = 0.0 # Invalid repitition_penalty
with pytest.raises(
ValueError,
match="repitition_penalty must be a positive float",
ValueError,
match="repitition_penalty must be a positive float",
):
yi34b_model(prompt, repitition_penalty=repitition_penalty)

@ -25,11 +25,16 @@ def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
zeroscope = ZeroscopeTTV()
mock_pipeline_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance)
mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
mock_pipeline.from_pretrained.return_value = (
mock_pipeline_instance
)
mock_pipeline_instance.return_value = MagicMock(
frames="Generated frames"
)
mock_pipeline_instance.enable_vae_slicing.assert_called_once()
mock_pipeline_instance.enable_forward_chunking.assert_called_once_with(
chunk_size=1, dim=1)
chunk_size=1, dim=1
)
result = zeroscope.forward("Test task")
assert result == "Generated frames"
mock_pipeline_instance.assert_called_once_with(
@ -46,8 +51,12 @@ def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
zeroscope = ZeroscopeTTV()
mock_pipeline_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance)
mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
mock_pipeline.from_pretrained.return_value = (
mock_pipeline_instance
)
mock_pipeline_instance.return_value = MagicMock(
frames="Generated frames"
)
mock_pipeline_instance.side_effect = Exception("Test error")
with pytest.raises(Exception, match="Test error"):
zeroscope.forward("Test task")
@ -58,8 +67,12 @@ def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
zeroscope = ZeroscopeTTV()
mock_pipeline_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance)
mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
mock_pipeline.from_pretrained.return_value = (
mock_pipeline_instance
)
mock_pipeline_instance.return_value = MagicMock(
frames="Generated frames"
)
result = zeroscope.__call__("Test task")
assert result == "Generated frames"
mock_pipeline_instance.assert_called_once_with(
@ -76,8 +89,12 @@ def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
zeroscope = ZeroscopeTTV()
mock_pipeline_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance)
mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
mock_pipeline.from_pretrained.return_value = (
mock_pipeline_instance
)
mock_pipeline_instance.return_value = MagicMock(
frames="Generated frames"
)
mock_pipeline_instance.side_effect = Exception("Test error")
with pytest.raises(Exception, match="Test error"):
zeroscope.__call__("Test task")
@ -88,8 +105,12 @@ def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
zeroscope = ZeroscopeTTV()
mock_pipeline_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = (mock_pipeline_instance)
mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
mock_pipeline.from_pretrained.return_value = (
mock_pipeline_instance
)
mock_pipeline_instance.return_value = MagicMock(
frames="Generated frames"
)
result = zeroscope.save_video_path("Test video path")
assert result == "Test video path"
mock_pipeline_instance.assert_called_once_with(

@ -18,7 +18,9 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
# Mocks and Fixtures
@pytest.fixture
def mocked_llm():
return OpenAIChat(openai_api_key=openai_api_key,)
return OpenAIChat(
openai_api_key=openai_api_key,
)
@pytest.fixture
@ -63,12 +65,15 @@ def test_provide_feedback(basic_flow):
@patch("time.sleep", return_value=None) # to speed up tests
def test_run_without_stopping_condition(mocked_sleep, basic_flow):
response = basic_flow.run("Test task")
assert (response == "Test task"
) # since our mocked llm doesn't modify the response
assert (
response == "Test task"
) # since our mocked llm doesn't modify the response
@patch("time.sleep", return_value=None) # to speed up tests
def test_run_with_stopping_condition(mocked_sleep, flow_with_condition):
def test_run_with_stopping_condition(
mocked_sleep, flow_with_condition
):
response = flow_with_condition.run("Stop")
assert response == "Stop"
@ -108,7 +113,6 @@ def test_env_variable_handling(monkeypatch):
# Test initializing the agent with different stopping conditions
def test_flow_with_custom_stopping_condition(mocked_llm):
def stopping_condition(x):
return "terminate" in x.lower()
@ -129,7 +133,9 @@ def test_flow_call(basic_flow):
# Test formatting the prompt
def test_format_prompt(basic_flow):
formatted_prompt = basic_flow.format_prompt("Hello {name}", name="John")
formatted_prompt = basic_flow.format_prompt(
"Hello {name}", name="John"
)
assert formatted_prompt == "Hello John"
@ -158,15 +164,9 @@ def test_interactive_mode(basic_flow):
# Test bulk run with varied inputs
def test_bulk_run_varied_inputs(basic_flow):
inputs = [
{
"task": "Test1"
},
{
"task": "Test2"
},
{
"task": "Stop now"
},
{"task": "Test1"},
{"task": "Test2"},
{"task": "Stop now"},
]
responses = basic_flow.bulk_run(inputs)
assert responses == ["Test1", "Test2", "Stop now"]
@ -191,9 +191,12 @@ def test_save_different_memory(basic_flow, tmp_path):
# Test the stopping condition check
def test_check_stopping_condition(flow_with_condition):
assert flow_with_condition._check_stopping_condition("Stop this process")
assert flow_with_condition._check_stopping_condition(
"Stop this process"
)
assert not flow_with_condition._check_stopping_condition(
"Continue the task")
"Continue the task"
)
# Test without providing max loops (default value should be 5)
@ -249,7 +252,9 @@ def test_different_retry_intervals(mocked_sleep, basic_flow):
# Test invoking the agent with additional kwargs
@patch("time.sleep", return_value=None)
def test_flow_call_with_kwargs(mocked_sleep, basic_flow):
response = basic_flow("Test call", param1="value1", param2="value2")
response = basic_flow(
"Test call", param1="value1", param2="value2"
)
assert response == "Test call"
@ -284,7 +289,9 @@ def test_stopping_token_in_response(mocked_sleep, basic_flow):
def flow_instance():
# Create an instance of the Agent class with required parameters for testing
# You may need to adjust this based on your actual class initialization
llm = OpenAIChat(openai_api_key=openai_api_key,)
llm = OpenAIChat(
openai_api_key=openai_api_key,
)
agent = Agent(
llm=llm,
max_loops=5,
@ -331,7 +338,9 @@ def test_flow_autosave(flow_instance):
def test_flow_response_filtering(flow_instance):
# Test the response filtering functionality
flow_instance.add_response_filter("filter_this")
response = flow_instance.filtered_run("This message should filter_this")
response = flow_instance.filtered_run(
"This message should filter_this"
)
assert "filter_this" not in response
@ -391,8 +400,11 @@ def test_flow_response_length(flow_instance):
# Test checking the length of the response
response = flow_instance.run(
"Generate a 10,000 word long blog on mental clarity and the"
" benefits of meditation.")
assert (len(response) > flow_instance.get_response_length_threshold())
" benefits of meditation."
)
assert (
len(response) > flow_instance.get_response_length_threshold()
)
def test_flow_set_response_length_threshold(flow_instance):
@ -481,7 +493,9 @@ def test_flow_get_conversation_log(flow_instance):
flow_instance.run("Message 1")
flow_instance.run("Message 2")
conversation_log = flow_instance.get_conversation_log()
assert (len(conversation_log) == 4) # Including system and user messages
assert (
len(conversation_log) == 4
) # Including system and user messages
def test_flow_clear_conversation_log(flow_instance):
@ -565,18 +579,37 @@ def test_flow_rollback(flow_instance):
flow_instance.change_prompt("New prompt")
flow_instance.get_state()
flow_instance.rollback_to_state(state1)
assert (flow_instance.get_current_prompt() == state1["current_prompt"])
assert (
flow_instance.get_current_prompt() == state1["current_prompt"]
)
assert flow_instance.get_instructions() == state1["instructions"]
assert (flow_instance.get_user_messages() == state1["user_messages"])
assert (flow_instance.get_response_history() == state1["response_history"])
assert (flow_instance.get_conversation_log() == state1["conversation_log"])
assert (flow_instance.is_dynamic_pacing_enabled() ==
state1["dynamic_pacing_enabled"])
assert (flow_instance.get_response_length_threshold() ==
state1["response_length_threshold"])
assert (flow_instance.get_response_filters() == state1["response_filters"])
assert (
flow_instance.get_user_messages() == state1["user_messages"]
)
assert (
flow_instance.get_response_history()
== state1["response_history"]
)
assert (
flow_instance.get_conversation_log()
== state1["conversation_log"]
)
assert (
flow_instance.is_dynamic_pacing_enabled()
== state1["dynamic_pacing_enabled"]
)
assert (
flow_instance.get_response_length_threshold()
== state1["response_length_threshold"]
)
assert (
flow_instance.get_response_filters()
== state1["response_filters"]
)
assert flow_instance.get_max_loops() == state1["max_loops"]
assert (flow_instance.get_autosave_path() == state1["autosave_path"])
assert (
flow_instance.get_autosave_path() == state1["autosave_path"]
)
assert flow_instance.get_state() == state1
@ -585,7 +618,8 @@ def test_flow_contextual_intent(flow_instance):
flow_instance.add_context("location", "New York")
flow_instance.add_context("time", "tomorrow")
response = flow_instance.run(
"What's the weather like in {location} at {time}?")
"What's the weather like in {location} at {time}?"
)
assert "New York" in response
assert "tomorrow" in response
@ -593,9 +627,13 @@ def test_flow_contextual_intent(flow_instance):
def test_flow_contextual_intent_override(flow_instance):
# Test contextual intent override
flow_instance.add_context("location", "New York")
response1 = flow_instance.run("What's the weather like in {location}?")
response1 = flow_instance.run(
"What's the weather like in {location}?"
)
flow_instance.add_context("location", "Los Angeles")
response2 = flow_instance.run("What's the weather like in {location}?")
response2 = flow_instance.run(
"What's the weather like in {location}?"
)
assert "New York" in response1
assert "Los Angeles" in response2
@ -603,9 +641,13 @@ def test_flow_contextual_intent_override(flow_instance):
def test_flow_contextual_intent_reset(flow_instance):
# Test resetting contextual intent
flow_instance.add_context("location", "New York")
response1 = flow_instance.run("What's the weather like in {location}?")
response1 = flow_instance.run(
"What's the weather like in {location}?"
)
flow_instance.reset_context()
response2 = flow_instance.run("What's the weather like in {location}?")
response2 = flow_instance.run(
"What's the weather like in {location}?"
)
assert "New York" in response1
assert "New York" in response2
@ -630,7 +672,9 @@ def test_flow_non_interruptible(flow_instance):
def test_flow_timeout(flow_instance):
# Test conversation timeout
flow_instance.timeout = 60 # Set a timeout of 60 seconds
response = flow_instance.run("This should take some time to respond.")
response = flow_instance.run(
"This should take some time to respond."
)
assert "Timed out" in response
assert flow_instance.is_timed_out() is True
@ -679,14 +723,20 @@ def test_flow_save_and_load_conversation(flow_instance):
def test_flow_inject_custom_system_message(flow_instance):
# Test injecting a custom system message into the conversation
flow_instance.inject_custom_system_message("Custom system message")
assert ("Custom system message" in flow_instance.get_message_history())
flow_instance.inject_custom_system_message(
"Custom system message"
)
assert (
"Custom system message" in flow_instance.get_message_history()
)
def test_flow_inject_custom_user_message(flow_instance):
# Test injecting a custom user message into the conversation
flow_instance.inject_custom_user_message("Custom user message")
assert ("Custom user message" in flow_instance.get_message_history())
assert (
"Custom user message" in flow_instance.get_message_history()
)
def test_flow_inject_custom_response(flow_instance):
@ -697,28 +747,45 @@ def test_flow_inject_custom_response(flow_instance):
def test_flow_clear_injected_messages(flow_instance):
# Test clearing injected messages from the conversation
flow_instance.inject_custom_system_message("Custom system message")
flow_instance.inject_custom_system_message(
"Custom system message"
)
flow_instance.inject_custom_user_message("Custom user message")
flow_instance.inject_custom_response("Custom response")
flow_instance.clear_injected_messages()
assert ("Custom system message" not in flow_instance.get_message_history())
assert ("Custom user message" not in flow_instance.get_message_history())
assert ("Custom response" not in flow_instance.get_message_history())
assert (
"Custom system message"
not in flow_instance.get_message_history()
)
assert (
"Custom user message"
not in flow_instance.get_message_history()
)
assert (
"Custom response" not in flow_instance.get_message_history()
)
def test_flow_disable_message_history(flow_instance):
# Test disabling message history recording
flow_instance.disable_message_history()
response = flow_instance.run(
"This message should not be recorded in history.")
assert ("This message should not be recorded in history." in response)
assert (len(flow_instance.get_message_history()) == 0) # History is empty
"This message should not be recorded in history."
)
assert (
"This message should not be recorded in history." in response
)
assert (
len(flow_instance.get_message_history()) == 0
) # History is empty
def test_flow_enable_message_history(flow_instance):
# Test enabling message history recording
flow_instance.enable_message_history()
response = flow_instance.run("This message should be recorded in history.")
response = flow_instance.run(
"This message should be recorded in history."
)
assert "This message should be recorded in history." in response
assert len(flow_instance.get_message_history()) == 1
@ -728,7 +795,9 @@ def test_flow_custom_logger(flow_instance):
custom_logger = logger # Replace with your custom logger class
flow_instance.set_logger(custom_logger)
response = flow_instance.run("Custom logger test")
assert ("Logged using custom logger" in response) # Verify logging message
assert (
"Logged using custom logger" in response
) # Verify logging message
def test_flow_batch_processing(flow_instance):
@ -802,35 +871,43 @@ def test_flow_input_validation(flow_instance):
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
"") # Empty delimiter, should raise ValueError
""
) # Empty delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
None) # None delimiter, should raise ValueError
None
) # None delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
123) # Invalid delimiter type, should raise ValueError
123
) # Invalid delimiter type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_logger(
"invalid_logger") # Invalid logger type, should raise ValueError
"invalid_logger"
) # Invalid logger type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context(None,
"value") # None key, should raise ValueError
flow_instance.add_context(
None, "value"
) # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context("key",
None) # None value, should raise ValueError
flow_instance.add_context(
"key", None
) # None value, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(
None, "value") # None key, should raise ValueError
None, "value"
) # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(
"key", None) # None value, should raise ValueError
"key", None
) # None value, should raise ValueError
def test_flow_conversation_reset(flow_instance):
@ -857,7 +934,6 @@ def test_flow_conversation_persistence(flow_instance):
def test_flow_custom_event_listener(flow_instance):
# Test custom event listener
class CustomEventListener:
def on_message_received(self, message):
pass
@ -869,10 +945,10 @@ def test_flow_custom_event_listener(flow_instance):
# Ensure that the custom event listener methods are called during a conversation
with mock.patch.object(
custom_event_listener,
"on_message_received") as mock_received, mock.patch.object(
custom_event_listener,
"on_response_generated") as mock_response:
custom_event_listener, "on_message_received"
) as mock_received, mock.patch.object(
custom_event_listener, "on_response_generated"
) as mock_response:
flow_instance.run("Message 1")
mock_received.assert_called_once()
mock_response.assert_called_once()
@ -881,7 +957,6 @@ def test_flow_custom_event_listener(flow_instance):
def test_flow_multiple_event_listeners(flow_instance):
# Test multiple event listeners
class FirstEventListener:
def on_message_received(self, message):
pass
@ -889,7 +964,6 @@ def test_flow_multiple_event_listeners(flow_instance):
pass
class SecondEventListener:
def on_message_received(self, message):
pass
@ -903,14 +977,14 @@ def test_flow_multiple_event_listeners(flow_instance):
# Ensure that both event listeners receive events during a conversation
with mock.patch.object(
first_event_listener,
"on_message_received") as mock_first_received, mock.patch.object(
first_event_listener, "on_response_generated"
) as mock_first_response, mock.patch.object(
second_event_listener, "on_message_received"
) as mock_second_received, mock.patch.object(
second_event_listener,
"on_response_generated") as mock_second_response:
first_event_listener, "on_message_received"
) as mock_first_received, mock.patch.object(
first_event_listener, "on_response_generated"
) as mock_first_response, mock.patch.object(
second_event_listener, "on_message_received"
) as mock_second_received, mock.patch.object(
second_event_listener, "on_response_generated"
) as mock_second_response:
flow_instance.run("Message 1")
mock_first_received.assert_called_once()
mock_first_response.assert_called_once()
@ -923,31 +997,38 @@ def test_flow_error_handling(flow_instance):
# Test error handling and exceptions
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
"") # Empty delimiter, should raise ValueError
""
) # Empty delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
None) # None delimiter, should raise ValueError
None
) # None delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_logger(
"invalid_logger") # Invalid logger type, should raise ValueError
"invalid_logger"
) # Invalid logger type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context(None,
"value") # None key, should raise ValueError
flow_instance.add_context(
None, "value"
) # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context("key",
None) # None value, should raise ValueError
flow_instance.add_context(
"key", None
) # None value, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(
None, "value") # None key, should raise ValueError
None, "value"
) # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(
"key", None) # None value, should raise ValueError
"key", None
) # None value, should raise ValueError
def test_flow_context_operations(flow_instance):
@ -984,8 +1065,14 @@ def test_flow_custom_response(flow_instance):
flow_instance.set_response_generator(custom_response_generator)
assert flow_instance.run("Hello") == "Hi there!"
assert (flow_instance.run("How are you?") == "I'm doing well, thank you.")
assert (flow_instance.run("What's your name?") == "I don't understand.")
assert (
flow_instance.run("How are you?")
== "I'm doing well, thank you."
)
assert (
flow_instance.run("What's your name?")
== "I don't understand."
)
def test_flow_message_validation(flow_instance):
@ -996,8 +1083,12 @@ def test_flow_message_validation(flow_instance):
flow_instance.set_message_validator(custom_message_validator)
assert flow_instance.run("Valid message") is not None
assert (flow_instance.run("") is None) # Empty message should be rejected
assert (flow_instance.run(None) is None) # None message should be rejected
assert (
flow_instance.run("") is None
) # Empty message should be rejected
assert (
flow_instance.run(None) is None
) # None message should be rejected
def test_flow_custom_logging(flow_instance):
@ -1022,10 +1113,15 @@ def test_flow_complex_use_case(flow_instance):
flow_instance.add_context("user_id", "12345")
flow_instance.run("Hello")
flow_instance.run("How can I help you?")
assert (flow_instance.get_response() == "Please provide more details.")
assert (
flow_instance.get_response() == "Please provide more details."
)
flow_instance.update_context("user_id", "54321")
flow_instance.run("I need help with my order")
assert (flow_instance.get_response() == "Sure, I can assist with that.")
assert (
flow_instance.get_response()
== "Sure, I can assist with that."
)
flow_instance.reset_conversation()
assert len(flow_instance.get_message_history()) == 0
assert flow_instance.get_context("user_id") is None
@ -1064,7 +1160,9 @@ def test_flow_concurrent_requests(flow_instance):
def test_flow_custom_timeout(flow_instance):
# Test custom timeout handling
flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds
flow_instance.set_timeout(
10
) # Set a custom timeout of 10 seconds
assert flow_instance.get_timeout() == 10
import time
@ -1115,10 +1213,16 @@ def test_flow_agent_history_prompt(flow_instance):
history = ["User: Hi", "AI: Hello"]
agent_history_prompt = flow_instance.agent_history_prompt(
system_prompt, history)
system_prompt, history
)
assert ("SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt)
assert ("History: ['User: Hi', 'AI: Hello']" in agent_history_prompt)
assert (
"SYSTEM_PROMPT: This is the system prompt."
in agent_history_prompt
)
assert (
"History: ['User: Hi', 'AI: Hello']" in agent_history_prompt
)
async def test_flow_run_concurrent(flow_instance):
@ -1133,18 +1237,9 @@ async def test_flow_run_concurrent(flow_instance):
def test_flow_bulk_run(flow_instance):
# Test bulk running of tasks
input_data = [
{
"task": "Task 1",
"param1": "value1"
},
{
"task": "Task 2",
"param2": "value2"
},
{
"task": "Task 3",
"param3": "value3"
},
{"task": "Task 1", "param1": "value1"},
{"task": "Task 2", "param2": "value2"},
{"task": "Task 3", "param3": "value3"},
]
responses = flow_instance.bulk_run(input_data)
@ -1159,7 +1254,9 @@ def test_flow_from_llm_and_template():
llm_instance = mocked_llm # Replace with your LLM class
template = "This is a template for testing."
flow_instance = Agent.from_llm_and_template(llm_instance, template)
flow_instance = Agent.from_llm_and_template(
llm_instance, template
)
assert isinstance(flow_instance, Agent)
@ -1168,10 +1265,12 @@ def test_flow_from_llm_and_template_file():
# Test creating Agent instance from an LLM and a template file
llm_instance = mocked_llm # Replace with your LLM class
template_file = ( # Create a template file for testing
"template.txt")
"template.txt"
)
flow_instance = Agent.from_llm_and_template_file(llm_instance,
template_file)
flow_instance = Agent.from_llm_and_template_file(
llm_instance, template_file
)
assert isinstance(flow_instance, Agent)

@ -44,7 +44,9 @@ def test_autoscaler_run():
agent.id,
"Generate a 10,000 word blog on health and wellness.",
)
assert (out == "Generate a 10,000 word blog on health and wellness.")
assert (
out == "Generate a 10,000 word blog on health and wellness."
)
def test_autoscaler_add_agent():
@ -237,7 +239,9 @@ def test_autoscaler_add_task():
def test_autoscaler_scale_up():
autoscaler = AutoScaler(initial_agents=5, scale_up_factor=2, agent=agent)
autoscaler = AutoScaler(
initial_agents=5, scale_up_factor=2, agent=agent
)
autoscaler.scale_up()
assert len(autoscaler.agents_pool) == 10

@ -7,7 +7,6 @@ from swarms.structs.base import BaseStructure
class TestBaseStructure:
def test_init(self):
base_structure = BaseStructure(
name="TestStructure",
@ -89,8 +88,11 @@ class TestBaseStructure:
with open(log_file) as file:
lines = file.readlines()
assert len(lines) == 1
assert (lines[0] == f"[{base_structure._current_timestamp()}]"
f" [{event_type}] {event}\n")
assert (
lines[0]
== f"[{base_structure._current_timestamp()}]"
f" [{event_type}] {event}\n"
)
@pytest.mark.asyncio
async def test_run_async(self):
@ -134,7 +136,9 @@ class TestBaseStructure:
artifact = {"key": "value"}
artifact_name = "test_artifact"
await base_structure.save_artifact_async(artifact, artifact_name)
await base_structure.save_artifact_async(
artifact, artifact_name
)
loaded_artifact = base_structure.load_artifact(artifact_name)
assert loaded_artifact == artifact
@ -147,8 +151,9 @@ class TestBaseStructure:
artifact = {"key": "value"}
artifact_name = "test_artifact"
base_structure.save_artifact(artifact, artifact_name)
loaded_artifact = await base_structure.load_artifact_async(artifact_name
)
loaded_artifact = await base_structure.load_artifact_async(
artifact_name
)
assert loaded_artifact == artifact
@ -165,8 +170,11 @@ class TestBaseStructure:
with open(log_file) as file:
lines = file.readlines()
assert len(lines) == 1
assert (lines[0] == f"[{base_structure._current_timestamp()}]"
f" [{event_type}] {event}\n")
assert (
lines[0]
== f"[{base_structure._current_timestamp()}]"
f" [{event_type}] {event}\n"
)
@pytest.mark.asyncio
async def test_asave_to_file(self, tmpdir):
@ -193,14 +201,18 @@ class TestBaseStructure:
def test_run_in_thread(self):
base_structure = BaseStructure()
result = base_structure.run_in_thread(lambda: "Thread Test Result")
result = base_structure.run_in_thread(
lambda: "Thread Test Result"
)
assert result.result() == "Thread Test Result"
def test_save_and_decompress_data(self):
base_structure = BaseStructure()
data = {"key": "value"}
compressed_data = base_structure.compress_data(data)
decompressed_data = base_structure.decompres_data(compressed_data)
decompressed_data = base_structure.decompres_data(
compressed_data
)
assert decompressed_data == data
def test_run_batched(self):
@ -210,11 +222,13 @@ class TestBaseStructure:
return f"Processed {data}"
batched_data = list(range(10))
result = base_structure.run_batched(batched_data,
batch_size=5,
func=run_function)
result = base_structure.run_batched(
batched_data, batch_size=5, func=run_function
)
expected_result = [f"Processed {data}" for data in batched_data]
expected_result = [
f"Processed {data}" for data in batched_data
]
assert result == expected_result
def test_load_config(self, tmpdir):
@ -232,12 +246,15 @@ class TestBaseStructure:
tmp_dir = tmpdir.mkdir("test_dir")
base_structure = BaseStructure()
data_to_backup = {"key": "value"}
base_structure.backup_data(data_to_backup, backup_path=tmp_dir)
base_structure.backup_data(
data_to_backup, backup_path=tmp_dir
)
backup_files = os.listdir(tmp_dir)
assert len(backup_files) == 1
loaded_data = base_structure.load_from_file(
os.path.join(tmp_dir, backup_files[0]))
os.path.join(tmp_dir, backup_files[0])
)
assert loaded_data == data_to_backup
def test_monitor_resources(self):
@ -262,9 +279,11 @@ class TestBaseStructure:
return f"Processed {data}"
batched_data = list(range(10))
result = base_structure.run_with_resources_batched(batched_data,
batch_size=5,
func=run_function)
result = base_structure.run_with_resources_batched(
batched_data, batch_size=5, func=run_function
)
expected_result = [f"Processed {data}" for data in batched_data]
expected_result = [
f"Processed {data}" for data in batched_data
]
assert result == expected_result

@ -30,8 +30,13 @@ def test_load_workflow_state():
workflow.load_workflow_state("workflow_state.json")
assert workflow.max_loops == 1
assert len(workflow.tasks) == 2
assert (workflow.tasks[0].description == "What's the weather in miami")
assert (workflow.tasks[1].description == "Create a report on these metrics")
assert (
workflow.tasks[0].description == "What's the weather in miami"
)
assert (
workflow.tasks[1].description
== "Create a report on these metrics"
)
teardown_workflow()

@ -18,7 +18,9 @@ def test_run():
workflow.add(task1)
workflow.add(task2)
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor:
with patch(
"concurrent.futures.ThreadPoolExecutor"
) as mock_executor:
future1 = Future()
future1.set_result(None)
future2 = Future()

@ -87,13 +87,16 @@ def test_return_history_as_string_with_different_roles(role, content):
@pytest.mark.parametrize("message_count", range(1, 11))
def test_return_history_as_string_with_multiple_messages(message_count,):
def test_return_history_as_string_with_multiple_messages(
message_count,
):
conv = Conversation()
for i in range(message_count):
conv.add("user", f"Message {i + 1}")
result = conv.return_history_as_string()
expected = "".join(
[f"user: Message {i + 1}\n\n" for i in range(message_count)])
[f"user: Message {i + 1}\n\n" for i in range(message_count)]
)
assert result == expected
@ -119,8 +122,10 @@ def test_return_history_as_string_with_large_message(conversation):
large_message = "Hello, world! " * 10000 # 10,000 repetitions
conversation.add("user", large_message)
result = conversation.return_history_as_string()
expected = ("user: Hello, world!\n\nassistant: Hello, user!\n\nuser:"
f" {large_message}\n\n")
expected = (
"user: Hello, world!\n\nassistant: Hello, user!\n\nuser:"
f" {large_message}\n\n"
)
assert result == expected
@ -136,8 +141,10 @@ def test_export_import_conversation(conversation, tmp_path):
conversation.export_conversation(filename)
new_conversation = Conversation()
new_conversation.import_conversation(filename)
assert (new_conversation.return_history_as_string() ==
conversation.return_history_as_string())
assert (
new_conversation.return_history_as_string()
== conversation.return_history_as_string()
)
def test_count_messages_by_role(conversation):

@ -11,7 +11,6 @@ llm2 = Anthropic()
# Mock the OpenAI class for testing
class MockOpenAI:
def __init__(self, *args, **kwargs):
pass
@ -140,9 +139,9 @@ def test_groupchat_manager_generate_reply():
selector = agent1
# Initialize GroupChatManager
manager = GroupChatManager(groupchat=groupchat,
selector=selector,
openai=mocked_openai)
manager = GroupChatManager(
groupchat=groupchat, selector=selector, openai=mocked_openai
)
# Generate a reply
task = "Write me a riddle"
@ -166,8 +165,9 @@ def test_groupchat_select_speaker():
# Simulate selecting the next speaker
last_speaker = agent1
next_speaker = manager.select_speaker(last_speaker=last_speaker,
selector=selector)
next_speaker = manager.select_speaker(
last_speaker=last_speaker, selector=selector
)
# Ensure the next speaker is agent2
assert next_speaker == agent2
@ -185,8 +185,9 @@ def test_groupchat_underpopulated_group():
# Simulate selecting the next speaker in an underpopulated group
last_speaker = agent1
next_speaker = manager.select_speaker(last_speaker=last_speaker,
selector=selector)
next_speaker = manager.select_speaker(
last_speaker=last_speaker, selector=selector
)
# Ensure the next speaker is the same as the last speaker in an underpopulated group
assert next_speaker == last_speaker
@ -204,13 +205,15 @@ def test_groupchat_max_rounds():
# Simulate the conversation with max rounds
last_speaker = agent1
for _ in range(2):
next_speaker = manager.select_speaker(last_speaker=last_speaker,
selector=selector)
next_speaker = manager.select_speaker(
last_speaker=last_speaker, selector=selector
)
last_speaker = next_speaker
# Try one more round, should stay with the last speaker
next_speaker = manager.select_speaker(last_speaker=last_speaker,
selector=selector)
next_speaker = manager.select_speaker(
last_speaker=last_speaker, selector=selector
)
# Ensure the next speaker is the same as the last speaker after reaching max rounds
assert next_speaker == last_speaker

@ -15,8 +15,10 @@ def valid_schema_path(tmp_path):
d = tmp_path / "sub"
d.mkdir()
p = d / "schema.json"
p.write_text('{"type": "object", "properties": {"name": {"type":'
' "string"}}}')
p.write_text(
'{"type": "object", "properties": {"name": {"type":'
' "string"}}}'
)
return str(p)
@ -31,7 +33,6 @@ def invalid_schema_path(tmp_path):
# This test class must be subclassed as JSON class is abstract
class TestableJSON(JSON):
def validate(self, data):
# Here must be a real validation implementation for testing
pass

@ -35,9 +35,15 @@ def test_majority_voting_run_concurrent(mocker):
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?")
agent2.run.assert_called_once_with("What is the capital of France?")
agent3.run.assert_called_once_with("What is the capital of France?")
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
@ -77,9 +83,15 @@ def test_majority_voting_run_multithreaded(mocker):
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?")
agent2.run.assert_called_once_with("What is the capital of France?")
agent3.run.assert_called_once_with("What is the capital of France?")
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
@ -121,9 +133,15 @@ async def test_majority_voting_run_asynchronous(mocker):
majority_vote = await mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with("What is the capital of France?")
agent2.run.assert_called_once_with("What is the capital of France?")
agent3.run.assert_called_once_with("What is the capital of France?")
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])

@ -8,7 +8,9 @@ def test_message_pool_initialization():
agent2 = Agent(llm=OpenAIChat(), agent_name="agent1")
moderator = Agent(llm=OpenAIChat(), agent_name="agent1")
agents = [agent1, agent2]
message_pool = MessagePool(agents=agents, moderator=moderator, turns=5)
message_pool = MessagePool(
agents=agents, moderator=moderator, turns=5
)
assert message_pool.agent == agents
assert message_pool.moderator == moderator
@ -18,21 +20,27 @@ def test_message_pool_initialization():
def test_message_pool_add():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
message_pool = MessagePool(
agents=[agent1], moderator=agent1, turns=5
)
message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.messages == [{
"agent": agent1,
"content": "Hello, world!",
"turn": 1,
"visible_to": "all",
"logged": True,
}]
assert message_pool.messages == [
{
"agent": agent1,
"content": "Hello, world!",
"turn": 1,
"visible_to": "all",
"logged": True,
}
]
def test_message_pool_reset():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
message_pool = MessagePool(
agents=[agent1], moderator=agent1, turns=5
)
message_pool.add(agent=agent1, content="Hello, world!", turn=1)
message_pool.reset()
@ -41,7 +49,9 @@ def test_message_pool_reset():
def test_message_pool_last_turn():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
message_pool = MessagePool(
agents=[agent1], moderator=agent1, turns=5
)
message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.last_turn() == 1
@ -49,7 +59,9 @@ def test_message_pool_last_turn():
def test_message_pool_last_message():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
message_pool = MessagePool(
agents=[agent1], moderator=agent1, turns=5
)
message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.last_message == {
@ -63,24 +75,28 @@ def test_message_pool_last_message():
def test_message_pool_get_all_messages():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
message_pool = MessagePool(
agents=[agent1], moderator=agent1, turns=5
)
message_pool.add(agent=agent1, content="Hello, world!", turn=1)
assert message_pool.get_all_messages() == [{
"agent": agent1,
"content": "Hello, world!",
"turn": 1,
"visible_to": "all",
"logged": True,
}]
assert message_pool.get_all_messages() == [
{
"agent": agent1,
"content": "Hello, world!",
"turn": 1,
"visible_to": "all",
"logged": True,
}
]
def test_message_pool_get_visible_messages():
agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
agent2 = Agent(agent_name="agent2")
message_pool = MessagePool(agents=[agent1, agent2],
moderator=agent1,
turns=5)
message_pool = MessagePool(
agents=[agent1, agent2], moderator=agent1, turns=5
)
message_pool.add(
agent=agent1,
content="Hello, agent2!",
@ -88,10 +104,14 @@ def test_message_pool_get_visible_messages():
visible_to=[agent2.agent_name],
)
assert message_pool.get_visible_messages(agent=agent2, turn=2) == [{
"agent": agent1,
"content": "Hello, agent2!",
"turn": 1,
"visible_to": [agent2.agent_name],
"logged": True,
}]
assert message_pool.get_visible_messages(
agent=agent2, turn=2
) == [
{
"agent": agent1,
"content": "Hello, agent2!",
"turn": 1,
"visible_to": [agent2.agent_name],
"logged": True,
}
]

@ -11,9 +11,7 @@ from swarms.structs.model_parallizer import ModelParallelizer
# Initialize the models
custom_config = {
"quantize": True,
"quantization_config": {
"load_in_4bit": True
},
"quantization_config": {"load_in_4bit": True},
"verbose": True,
}
huggingface_llm = HuggingfaceLLM(
@ -26,12 +24,14 @@ zeroscope_ttv = ZeroscopeTTV()
def test_init():
mp = ModelParallelizer([
huggingface_llm,
mixtral,
gpt4_vision_api,
zeroscope_ttv,
])
mp = ModelParallelizer(
[
huggingface_llm,
mixtral,
gpt4_vision_api,
zeroscope_ttv,
]
)
assert isinstance(mp, ModelParallelizer)
@ -39,20 +39,24 @@ def test_run():
mp = ModelParallelizer([huggingface_llm])
result = mp.run(
"Create a list of known biggest risks of structural collapse"
" with references")
" with references"
)
assert isinstance(result, str)
def test_run_all():
mp = ModelParallelizer([
huggingface_llm,
mixtral,
gpt4_vision_api,
zeroscope_ttv,
])
mp = ModelParallelizer(
[
huggingface_llm,
mixtral,
gpt4_vision_api,
zeroscope_ttv,
]
)
result = mp.run_all(
"Create a list of known biggest risks of structural collapse"
" with references")
" with references"
)
assert isinstance(result, list)
assert len(result) == 5
@ -71,8 +75,10 @@ def test_remove_llm():
def test_save_responses_to_file(tmp_path):
mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse"
" with references")
mp.run(
"Create a list of known biggest risks of structural collapse"
" with references"
)
file = tmp_path / "responses.txt"
mp.save_responses_to_file(file)
assert file.read_text() != ""
@ -80,8 +86,10 @@ def test_save_responses_to_file(tmp_path):
def test_get_task_history():
mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse"
" with references")
mp.run(
"Create a list of known biggest risks of structural collapse"
" with references"
)
assert mp.get_task_history() == [
"Create a list of known biggest risks of structural collapse"
" with references"
@ -90,8 +98,10 @@ def test_get_task_history():
def test_summary(capsys):
mp = ModelParallelizer([huggingface_llm])
mp.run("Create a list of known biggest risks of structural collapse"
" with references")
mp.run(
"Create a list of known biggest risks of structural collapse"
" with references"
)
mp.summary()
captured = capsys.readouterr()
assert "Tasks History:" in captured.out
@ -113,7 +123,8 @@ def test_concurrent_run():
mp = ModelParallelizer([huggingface_llm, mixtral])
result = mp.concurrent_run(
"Create a list of known biggest risks of structural collapse"
" with references")
" with references"
)
assert isinstance(result, list)
assert len(result) == 2

@ -73,8 +73,12 @@ def test_run(collaboration):
def test_format_results(collaboration):
collaboration.results = [{"agent": "Agent1", "response": "Response1"}]
formatted_results = collaboration.format_results(collaboration.results)
collaboration.results = [
{"agent": "Agent1", "response": "Response1"}
]
formatted_results = collaboration.format_results(
collaboration.results
)
assert "Agent1 responded: Response1" in formatted_results
@ -108,10 +112,7 @@ def test_repr(collaboration):
def test_load(collaboration):
state = {
"step": 5,
"results": [{
"agent": "Agent1",
"response": "Response1"
}],
"results": [{"agent": "Agent1", "response": "Response1"}],
}
with open(collaboration.saved_file_path_name, "w") as file:
json.dump(state, file)

@ -5,7 +5,6 @@ from swarms.structs import NonlinearWorkflow, Task
class TestNonlinearWorkflow:
def test_add_task(self):
llm = OpenAIChat(openai_api_key="")
task = Task(llm, "What's the weather in miami")
@ -34,7 +33,9 @@ class TestNonlinearWorkflow:
workflow = NonlinearWorkflow()
workflow.add(task1, task2.name)
workflow.add(task2, task1.name)
with pytest.raises(Exception, match="Circular dependency detected"):
with pytest.raises(
Exception, match="Circular dependency detected"
):
workflow.run()
def test_run_with_stopping_token(self):

@ -53,7 +53,9 @@ def test_run_stop_token_not_in_result():
try:
workflow.run()
except RecursionError:
pytest.fail("RecursiveWorkflow.run caused a RecursionError")
pytest.fail(
"RecursiveWorkflow.run caused a RecursionError"
)
assert agent.execute.call_count == max_iterations

@ -17,7 +17,6 @@ os.environ["OPENAI_API_KEY"] = "mocked_api_key"
# Mock OpenAIChat class for testing
class MockOpenAIChat:
def __init__(self, *args, **kwargs):
pass
@ -27,7 +26,6 @@ class MockOpenAIChat:
# Mock Agent class for testing
class MockAgent:
def __init__(self, *args, **kwargs):
pass
@ -37,7 +35,6 @@ class MockAgent:
# Mock SequentialWorkflow class for testing
class MockSequentialWorkflow:
def __init__(self, *args, **kwargs):
pass
@ -72,7 +69,10 @@ def test_sequential_workflow_initialization():
assert len(workflow.tasks) == 0
assert workflow.max_loops == 1
assert workflow.autosave is False
assert (workflow.saved_state_filepath == "sequential_workflow_state.json")
assert (
workflow.saved_state_filepath
== "sequential_workflow_state.json"
)
assert workflow.restore_state_filepath is None
assert workflow.dashboard is False
@ -177,7 +177,6 @@ def test_sequential_workflow_workflow_dashboard(capfd):
# Mock Agent class for async testing
class MockAsyncAgent:
def __init__(self, *args, **kwargs):
pass

@ -20,7 +20,9 @@ def test_swarm_network_init(swarm_network):
@patch("swarms.structs.swarm_net.SwarmNetwork.logger")
def test_run(mock_logger, swarm_network):
swarm_network.run()
assert (mock_logger.info.call_count == 10) # 2 log messages per agent
assert (
mock_logger.info.call_count == 10
) # 2 log messages per agent
def test_run_with_mocked_agents(mocker, swarm_network):

@ -7,7 +7,8 @@ from dotenv import load_dotenv
from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,)
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)
from swarms.structs.agent import Agent
from swarms.structs.task import Task
@ -20,11 +21,13 @@ def llm():
def test_agent_run_task(llm):
task = ("Analyze this image of an assembly line and identify any"
" issues such as misaligned parts, defects, or deviations"
" from the standard assembly process. IF there is anything"
" unsafe in the image, explain why it is unsafe and how it"
" could be improved.")
task = (
"Analyze this image of an assembly line and identify any"
" issues such as misaligned parts, defects, or deviations"
" from the standard assembly process. IF there is anything"
" unsafe in the image, explain why it is unsafe and how it"
" could be improved."
)
img = "assembly_line.jpg"
agent = Agent(
@ -46,7 +49,9 @@ def test_agent_run_task(llm):
@pytest.fixture
def task():
agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)]
return Task(id="Task_1", task="Task_Name", agents=agents, dependencies=[])
return Task(
id="Task_1", task="Task_Name", agents=agents, dependencies=[]
)
# Basic tests
@ -184,7 +189,9 @@ def test_task_execute_with_condition(mocker):
mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result"
condition = mocker.Mock(return_value=True)
task = Task(description="Test task", agent=mock_agent, condition=condition)
task = Task(
description="Test task", agent=mock_agent, condition=condition
)
task.execute()
assert task.result == "result"
assert task.history == ["result"]
@ -194,7 +201,9 @@ def test_task_execute_with_condition_false(mocker):
mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result"
condition = mocker.Mock(return_value=False)
task = Task(description="Test task", agent=mock_agent, condition=condition)
task = Task(
description="Test task", agent=mock_agent, condition=condition
)
task.execute()
assert task.result is None
assert task.history == []
@ -204,7 +213,9 @@ def test_task_execute_with_action(mocker):
mock_agent = mocker.Mock(spec=Agent)
mock_agent.run.return_value = "result"
action = mocker.Mock()
task = Task(description="Test task", agent=mock_agent, action=action)
task = Task(
description="Test task", agent=mock_agent, action=action
)
task.execute()
assert task.result == "result"
assert task.history == ["result"]
@ -232,9 +243,11 @@ def test_task_handle_scheduled_task_future(mocker):
agent=mock_agent,
schedule_time=datetime.now() + timedelta(days=1),
)
with mocker.patch.object(task.scheduler,
"enter") as mock_enter, mocker.patch.object(
task.scheduler, "run") as mock_run:
with mocker.patch.object(
task.scheduler, "enter"
) as mock_enter, mocker.patch.object(
task.scheduler, "run"
) as mock_run:
task.handle_scheduled_task()
mock_enter.assert_called_once()
mock_run.assert_called_once()

@ -21,9 +21,7 @@ def agent():
@pytest.fixture()
def concrete_task_queue():
class ConcreteTaskQueue(TaskQueueBase):
def add_task(self, task):
pass # Here you would add concrete implementation of add_task
@ -53,8 +51,9 @@ def test_add_task_failure(concrete_task_queue, task):
# Assuming the task is somehow invalid
# Note: Concrete implementation requires logic defining what an invalid task is
concrete_task_queue.add_task(task)
assert (concrete_task_queue.add_task(task)
is False) # Adding the same task again
assert (
concrete_task_queue.add_task(task) is False
) # Adding the same task again
@pytest.mark.parametrize("invalid_task", [None, "", {}, []])

@ -7,7 +7,6 @@ from swarms.structs.team import Team
class TestTeam(unittest.TestCase):
def setUp(self):
self.agent = Agent(
llm=OpenAIChat(openai_api_key=""),
@ -31,17 +30,16 @@ class TestTeam(unittest.TestCase):
with self.assertRaises(ValueError):
self.team.check_config(
{"config": json.dumps({
"agents": [],
"tasks": []
})})
{"config": json.dumps({"agents": [], "tasks": []})}
)
def test_run(self):
self.assertEqual(self.team.run(), self.task.execute())
def test_sequential_loop(self):
self.assertEqual(self.team._Team__sequential_loop(),
self.task.execute())
self.assertEqual(
self.team._Team__sequential_loop(), self.task.execute()
)
def test_log(self):
self.assertIsNone(self.team._Team__log("Test message"))

@ -27,7 +27,9 @@ def test_set_entry_point(graph_workflow):
def test_set_entry_point_nonexistent_node(graph_workflow):
with pytest.raises(ValueError, match="Node does not exist in graph"):
with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.set_entry_point("nonexistent")
@ -40,23 +42,29 @@ def test_add_edge(graph_workflow):
def test_add_edge_nonexistent_node(graph_workflow):
graph_workflow.add("node1", "value1")
with pytest.raises(ValueError, match="Node does not exist in graph"):
with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.add_edge("node1", "nonexistent")
def test_add_conditional_edges(graph_workflow):
graph_workflow.add("node1", "value1")
graph_workflow.add("node2", "value2")
graph_workflow.add_conditional_edges("node1", "condition1",
{"condition_value1": "node2"})
graph_workflow.add_conditional_edges(
"node1", "condition1", {"condition_value1": "node2"}
)
assert "node2" in graph_workflow.graph["node1"]["edges"]
def test_add_conditional_edges_nonexistent_node(graph_workflow):
graph_workflow.add("node1", "value1")
with pytest.raises(ValueError, match="Node does not exist in graph"):
with pytest.raises(
ValueError, match="Node does not exist in graph"
):
graph_workflow.add_conditional_edges(
"node1", "condition1", {"condition_value1": "nonexistent"})
"node1", "condition1", {"condition_value1": "nonexistent"}
)
def test_run(graph_workflow):

@ -35,8 +35,9 @@ def test_log_activity_posthog(mock_posthog, mock_env):
test_function()
# Check if the Posthog capture method was called with the expected arguments
mock_posthog.capture.assert_called_once_with("test_user_id", event_name,
event_properties)
mock_posthog.capture.assert_called_once_with(
"test_user_id", event_name, event_properties
)
# Test a scenario where environment variables are not set

@ -46,7 +46,9 @@ def test_generate_unique_identifier():
# Generate unique identifiers and ensure they are valid UUID strings
unique_id = generate_unique_identifier()
assert isinstance(unique_id, str)
assert uuid.UUID(unique_id, version=5, namespace=uuid.NAMESPACE_DNS)
assert uuid.UUID(
unique_id, version=5, namespace=uuid.NAMESPACE_DNS
)
def test_generate_user_id_edge_case():
@ -71,7 +73,9 @@ def test_get_system_info_edge_case():
# Test get_system_info for consistency
system_info1 = get_system_info()
system_info2 = get_system_info()
assert (system_info1 == system_info2) # Ensure system info remains the same
assert (
system_info1 == system_info2
) # Ensure system info remains the same
def test_generate_unique_identifier_edge_case():

@ -20,7 +20,9 @@ headers = {
def run_pytest():
result = subprocess.run(["pytest"], capture_output=True, text=True)
result = subprocess.run(
["pytest"], capture_output=True, text=True
)
return result.stdout + result.stderr
@ -54,7 +56,9 @@ def main():
errors = parse_pytest_output(pytest_output)
for error in errors:
issue_response = create_github_issue(error["title"], error["body"])
issue_response = create_github_issue(
error["title"], error["body"]
)
print(f"Issue created: {issue_response.get('html_url')}")

@ -16,8 +16,9 @@ def test_default_max_tokens():
assert tokenizer.default_max_tokens() == 100000
@pytest.mark.parametrize("model,tokens", [("claude-2.1", 200000),
("claude", 100000)])
@pytest.mark.parametrize(
"model,tokens", [("claude-2.1", 200000), ("claude", 100000)]
)
def test_default_max_tokens_models(model, tokens):
tokenizer = AnthropicTokenizer(model=model)
assert tokenizer.default_max_tokens() == tokens

@ -18,7 +18,9 @@ def test_post_init(base_tokenizer):
# 3. Tests for count_tokens_left with different inputs.
def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch):
def test_count_tokens_left_with_positive_diff(
base_tokenizer, monkeypatch
):
# Mocking count_tokens to return a specific value
monkeypatch.setattr(
"swarms.tokenizers.BaseTokenizer.count_tokens",
@ -27,7 +29,9 @@ def test_count_tokens_left_with_positive_diff(base_tokenizer, monkeypatch):
assert base_tokenizer.count_tokens_left("some text") == 50
def test_count_tokens_left_with_zero_diff(base_tokenizer, monkeypatch):
def test_count_tokens_left_with_zero_diff(
base_tokenizer, monkeypatch
):
monkeypatch.setattr(
"swarms.tokenizers.BaseTokenizer.count_tokens",
lambda x, y: 100,

@ -51,10 +51,18 @@ def test_prefix_space_tokens(hftokenizer):
# testing _maybe_add_prefix_space method
def test__maybe_add_prefix_space(hftokenizer):
assert (hftokenizer._maybe_add_prefix_space(
[101, 2003, 2010, 2050, 2001, 2339], " is why") == " is why")
assert (hftokenizer._maybe_add_prefix_space([2003, 2010, 2050, 2001, 2339],
"is why") == " is why")
assert (
hftokenizer._maybe_add_prefix_space(
[101, 2003, 2010, 2050, 2001, 2339], " is why"
)
== " is why"
)
assert (
hftokenizer._maybe_add_prefix_space(
[2003, 2010, 2050, 2001, 2339], "is why"
)
== " is why"
)
# continuing tests for other methods...

@ -18,21 +18,31 @@ def test_default_max_tokens(openai_tokenizer):
assert openai_tokenizer.default_max_tokens() == 4096
@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3),
(["Hello"], 4)])
@pytest.mark.parametrize(
"text, expected_output", [("Hello, world!", 3), (["Hello"], 4)]
)
def test_count_tokens_single(openai_tokenizer, text, expected_output):
assert (openai_tokenizer.count_tokens(text, "gpt-3") == expected_output)
assert (
openai_tokenizer.count_tokens(text, "gpt-3")
== expected_output
)
@pytest.mark.parametrize(
"texts, expected_output",
[(["Hello, world!", "This is a test"], 6), (["Hello"], 4)],
)
def test_count_tokens_multiple(openai_tokenizer, texts, expected_output):
assert (openai_tokenizer.count_tokens(texts, "gpt-3") == expected_output)
def test_count_tokens_multiple(
openai_tokenizer, texts, expected_output
):
assert (
openai_tokenizer.count_tokens(texts, "gpt-3")
== expected_output
)
@pytest.mark.parametrize("text, expected_output", [("Hello, world!", 3),
(["Hello"], 4)])
@pytest.mark.parametrize(
"text, expected_output", [("Hello, world!", 3), (["Hello"], 4)]
)
def test_len(openai_tokenizer, text, expected_output):
assert openai_tokenizer.len(text, "gpt-3") == expected_output

@ -7,7 +7,9 @@ from swarms.tokenizers.r_tokenizers import Tokenizer
def test_initializer_existing_model_file():
with patch("os.path.exists", return_value=True):
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
mock_model.assert_called_with("tokenizers/my_model.model")
assert tokenizer.model == mock_model.return_value
@ -15,43 +17,66 @@ def test_initializer_existing_model_file():
def test_initializer_model_folder():
with patch("os.path.exists", side_effect=[False, True]):
with patch("swarms.tokenizers.HuggingFaceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.HuggingFaceTokenizer"
) as mock_model:
tokenizer = Tokenizer("my_model_directory")
mock_model.assert_called_with("my_model_directory")
assert tokenizer.model == mock_model.return_value
def test_vocab_size():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.vocab_size == mock_model.return_value.vocab_size)
assert (
tokenizer.vocab_size == mock_model.return_value.vocab_size
)
def test_bos_token_id():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.bos_token_id == mock_model.return_value.bos_token_id)
assert (
tokenizer.bos_token_id
== mock_model.return_value.bos_token_id
)
def test_encode():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.encode("hello") ==
mock_model.return_value.encode.return_value)
assert (
tokenizer.encode("hello")
== mock_model.return_value.encode.return_value
)
def test_decode():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
assert (tokenizer.decode(
[1, 2, 3]) == mock_model.return_value.decode.return_value)
assert (
tokenizer.decode([1, 2, 3])
== mock_model.return_value.decode.return_value
)
def test_call():
with patch("swarms.tokenizers.SentencePieceTokenizer") as mock_model:
with patch(
"swarms.tokenizers.SentencePieceTokenizer"
) as mock_model:
tokenizer = Tokenizer("tokenizers/my_model.model")
assert (
tokenizer("hello") == mock_model.return_value.__call__.return_value)
tokenizer("hello")
== mock_model.return_value.__call__.return_value
)
# More tests can be added here

@ -65,20 +65,23 @@ def test_structured_tool_invoke():
def test_tool_creation():
tool = Tool(name="test_tool", func=lambda x: x, description="Test tool")
tool = Tool(
name="test_tool", func=lambda x: x, description="Test tool"
)
assert tool.name == "test_tool"
assert tool.func is not None
assert tool.description == "Test tool"
def test_tool_ainvoke():
tool = Tool(name="test_tool", func=lambda x: x, description="Test tool")
tool = Tool(
name="test_tool", func=lambda x: x, description="Test tool"
)
result = tool.ainvoke("input_data")
assert result == "input_data"
def test_tool_ainvoke_with_coroutine():
async def async_function(input_data):
return input_data
@ -92,7 +95,6 @@ def test_tool_ainvoke_with_coroutine():
def test_tool_args():
def sample_function(input_data):
return input_data
@ -108,7 +110,6 @@ def test_tool_args():
def test_structured_tool_creation():
class SampleArgsSchema:
pass
@ -125,7 +126,6 @@ def test_structured_tool_creation():
def test_structured_tool_ainvoke():
class SampleArgsSchema:
pass
@ -140,7 +140,6 @@ def test_structured_tool_ainvoke():
def test_structured_tool_ainvoke_with_coroutine():
class SampleArgsSchema:
pass
@ -158,7 +157,6 @@ def test_structured_tool_ainvoke_with_coroutine():
def test_structured_tool_args():
class SampleArgsSchema:
pass
@ -184,13 +182,14 @@ def test_tool_ainvoke_exception():
def test_tool_ainvoke_with_coroutine_exception():
tool = Tool(name="test_tool", coroutine=None, description="Test tool")
tool = Tool(
name="test_tool", coroutine=None, description="Test tool"
)
with pytest.raises(NotImplementedError):
tool.ainvoke("input_data")
def test_structured_tool_ainvoke_exception():
class SampleArgsSchema:
pass
@ -205,7 +204,6 @@ def test_structured_tool_ainvoke_exception():
def test_structured_tool_ainvoke_with_coroutine_exception():
class SampleArgsSchema:
pass
@ -227,7 +225,6 @@ def test_tool_description_not_provided():
def test_tool_invoke_with_callbacks():
def sample_function(input_data, callbacks=None):
if callbacks:
callbacks.on_start()
@ -243,7 +240,6 @@ def test_tool_invoke_with_callbacks():
def test_tool_invoke_with_new_argument():
def sample_function(input_data, callbacks=None):
return input_data
@ -253,7 +249,6 @@ def test_tool_invoke_with_new_argument():
def test_tool_ainvoke_with_new_argument():
async def async_function(input_data, callbacks=None):
return input_data
@ -263,7 +258,6 @@ def test_tool_ainvoke_with_new_argument():
def test_tool_description_from_docstring():
def sample_function(input_data):
"""Sample function docstring"""
return input_data
@ -273,7 +267,6 @@ def test_tool_description_from_docstring():
def test_tool_ainvoke_with_exceptions():
async def async_function(input_data):
raise ValueError("Test exception")
@ -286,7 +279,6 @@ def test_tool_ainvoke_with_exceptions():
def test_structured_tool_infer_schema_false():
def sample_function(input_data):
return input_data
@ -300,7 +292,6 @@ def test_structured_tool_infer_schema_false():
def test_structured_tool_ainvoke_with_callbacks():
class SampleArgsSchema:
pass
@ -316,14 +307,15 @@ def test_structured_tool_ainvoke_with_callbacks():
args_schema=SampleArgsSchema,
)
callbacks = MagicMock()
result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks)
result = tool.ainvoke(
{"tool_input": "input_data"}, callbacks=callbacks
)
assert result == "input_data"
callbacks.on_start.assert_called_once()
callbacks.on_finish.assert_called_once()
def test_structured_tool_description_not_provided():
class SampleArgsSchema:
pass
@ -338,7 +330,6 @@ def test_structured_tool_description_not_provided():
def test_structured_tool_args_schema():
class SampleArgsSchema:
pass
@ -354,7 +345,6 @@ def test_structured_tool_args_schema():
def test_structured_tool_args_schema_inference():
def sample_function(input_data):
return input_data
@ -368,7 +358,6 @@ def test_structured_tool_args_schema_inference():
def test_structured_tool_ainvoke_with_new_argument():
class SampleArgsSchema:
pass
@ -380,12 +369,13 @@ def test_structured_tool_ainvoke_with_new_argument():
func=sample_function,
args_schema=SampleArgsSchema,
)
result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None)
result = tool.ainvoke(
{"tool_input": "input_data"}, callbacks=None
)
assert result == "input_data"
def test_structured_tool_ainvoke_with_exceptions():
class SampleArgsSchema:
pass
@ -471,7 +461,9 @@ def test_tool_with_runnable(mock_runnable):
def test_tool_with_invalid_argument():
# Test passing an invalid argument type
with pytest.raises(ValueError):
tool(123) # Using an integer instead of a string/callable/Runnable
tool(
123
) # Using an integer instead of a string/callable/Runnable
def test_tool_with_multiple_arguments(mock_func):
@ -533,7 +525,9 @@ class MockSchema(BaseModel):
# Test suite starts here
class TestTool:
# Basic Functionality Tests
def test_tool_with_valid_callable_creates_base_tool(self, mock_func):
def test_tool_with_valid_callable_creates_base_tool(
self, mock_func
):
result = tool(mock_func)
assert isinstance(result, BaseTool)
@ -566,7 +560,6 @@ class TestTool:
# Error Handling Tests
def test_tool_raises_error_without_docstring(self):
def no_doc_func(arg: str) -> str:
return arg
@ -574,14 +567,14 @@ class TestTool:
tool(no_doc_func)
def test_tool_raises_error_runnable_without_object_schema(
self, mock_runnable):
self, mock_runnable
):
with pytest.raises(ValueError):
tool(mock_runnable)
# Decorator Behavior Tests
@pytest.mark.asyncio
async def test_async_tool_function(self):
@tool
async def async_func(arg: str) -> str:
return arg
@ -604,7 +597,6 @@ class TestTool:
pass
def test_tool_with_different_return_types(self):
@tool
def return_int(arg: str) -> int:
return int(arg)
@ -623,7 +615,6 @@ class TestTool:
# Test with multiple arguments
def test_tool_with_multiple_args(self):
@tool
def concat_strings(a: str, b: str) -> str:
return a + b
@ -633,7 +624,6 @@ class TestTool:
# Test handling of optional arguments
def test_tool_with_optional_args(self):
@tool
def greet(name: str, greeting: str = "Hello") -> str:
return f"{greeting} {name}"
@ -643,7 +633,6 @@ class TestTool:
# Test with variadic arguments
def test_tool_with_variadic_args(self):
@tool
def sum_numbers(*numbers: int) -> int:
return sum(numbers)
@ -653,7 +642,6 @@ class TestTool:
# Test with keyword arguments
def test_tool_with_kwargs(self):
@tool
def build_query(**kwargs) -> str:
return "&".join(f"{k}={v}" for k, v in kwargs.items())
@ -663,7 +651,6 @@ class TestTool:
# Test with mixed types of arguments
def test_tool_with_mixed_args(self):
@tool
def mixed_args(a: int, b: str, *args, **kwargs) -> str:
return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}"
@ -672,7 +659,6 @@ class TestTool:
# Test error handling with incorrect types
def test_tool_error_with_incorrect_types(self):
@tool
def add_numbers(a: int, b: int) -> int:
return a + b
@ -682,7 +668,6 @@ class TestTool:
# Test with nested tools
def test_nested_tools(self):
@tool
def inner_tool(arg: str) -> str:
return f"Inner {arg}"
@ -694,7 +679,6 @@ class TestTool:
assert outer_tool("Test") == "Outer Inner Test"
def test_tool_with_global_variable(self):
@tool
def access_global(arg: str) -> str:
return f"{global_var} {arg}"
@ -717,7 +701,6 @@ class TestTool:
# Test with complex data structures
def test_tool_with_complex_data_structures(self):
@tool
def process_data(data: dict) -> list:
return [data[key] for key in sorted(data.keys())]
@ -727,7 +710,6 @@ class TestTool:
# Test handling exceptions within the tool function
def test_tool_handling_internal_exceptions(self):
@tool
def function_that_raises(arg: str):
if arg == "error":
@ -740,7 +722,6 @@ class TestTool:
# Test with functions returning None
def test_tool_with_none_return(self):
@tool
def return_none(arg: str):
return None
@ -754,9 +735,7 @@ class TestTool:
# Test with class methods
def test_tool_with_class_method(self):
class MyClass:
@tool
def method(self, arg: str) -> str:
return f"Method {arg}"
@ -766,15 +745,12 @@ class TestTool:
# Test tool function with inheritance
def test_tool_with_inheritance(self):
class Parent:
@tool
def parent_method(self, arg: str) -> str:
return f"Parent {arg}"
class Child(Parent):
@tool
def child_method(self, arg: str) -> str:
return f"Child {arg}"
@ -785,9 +761,7 @@ class TestTool:
# Test with decorators stacking
def test_tool_with_multiple_decorators(self):
def another_decorator(func):
def wrapper(*args, **kwargs):
return f"Decorated {func(*args, **kwargs)}"
@ -813,7 +787,9 @@ class TestTool:
def thread_target():
results.append(threaded_function(5))
threads = [threading.Thread(target=thread_target) for _ in range(10)]
threads = [
threading.Thread(target=thread_target) for _ in range(10)
]
for t in threads:
t.start()
for t in threads:
@ -823,7 +799,6 @@ class TestTool:
# Test with recursive functions
def test_tool_with_recursive_function(self):
@tool
def recursive_function(n: int) -> int:
if n == 0:

@ -19,8 +19,9 @@ def test_check_device_no_cuda(monkeypatch):
def test_check_device_cuda_exception(monkeypatch):
# Mock torch.cuda.is_available to raise an exception
monkeypatch.setattr(torch.cuda, "is_available",
lambda: 1 / 0) # Raises ZeroDivisionError
monkeypatch.setattr(
torch.cuda, "is_available", lambda: 1 / 0
) # Raises ZeroDivisionError
result = check_device(log_level=logging.DEBUG)
assert result.type == "cpu"
@ -32,8 +33,12 @@ def test_check_device_one_cuda(monkeypatch):
# Mock torch.cuda.device_count to return 1
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
# Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
monkeypatch.setattr(
torch.cuda, "memory_allocated", lambda device: 0
)
monkeypatch.setattr(
torch.cuda, "memory_reserved", lambda device: 0
)
result = check_device(log_level=logging.DEBUG)
assert len(result) == 1
@ -47,8 +52,12 @@ def test_check_device_multiple_cuda(monkeypatch):
# Mock torch.cuda.device_count to return 4
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
# Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
monkeypatch.setattr(
torch.cuda, "memory_allocated", lambda device: 0
)
monkeypatch.setattr(
torch.cuda, "memory_reserved", lambda device: 0
)
result = check_device(log_level=logging.DEBUG)
assert len(result) == 4

@ -19,7 +19,8 @@ def test_print_class_parameters_agent():
# Replace with the expected output for Agent class
expected_output = (
"Parameter: name, Type: <class 'str'>\nParameter: age, Type:"
" <class 'int'>")
" <class 'int'>"
)
assert output == expected_output
@ -32,7 +33,9 @@ def test_print_class_parameters_error():
def get_parameters(class_name: str):
classes = {"Agent": Agent}
if class_name in classes:
return print_class_parameters(classes[class_name], api_format=True)
return print_class_parameters(
classes[class_name], api_format=True
)
else:
return {"error": "Class not found"}

@ -32,14 +32,18 @@ def test_multiple_gpus_available(mocker):
def test_device_properties(mocker):
mocker.patch("torch.cuda.is_available", return_value=True)
mocker.patch("torch.cuda.device_count", return_value=1)
mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
mocker.patch(
"torch.cuda.get_device_capability", return_value=(7, 5)
)
mocker.patch(
"torch.cuda.get_device_properties",
return_value=MagicMock(total_memory=1000),
)
mocker.patch("torch.cuda.memory_allocated", return_value=200)
mocker.patch("torch.cuda.memory_reserved", return_value=300)
mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
mocker.patch(
"torch.cuda.get_device_name", return_value="Tesla K80"
)
devices = check_device()
assert len(devices) == 1
assert str(devices[0]) == "cuda"
@ -48,21 +52,27 @@ def test_device_properties(mocker):
def test_memory_threshold(mocker):
mocker.patch("torch.cuda.is_available", return_value=True)
mocker.patch("torch.cuda.device_count", return_value=1)
mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
mocker.patch(
"torch.cuda.get_device_capability", return_value=(7, 5)
)
mocker.patch(
"torch.cuda.get_device_properties",
return_value=MagicMock(total_memory=1000),
)
mocker.patch("torch.cuda.memory_allocated",
return_value=900) # 90% of total memory
mocker.patch(
"torch.cuda.memory_allocated", return_value=900
) # 90% of total memory
mocker.patch("torch.cuda.memory_reserved", return_value=300)
mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
mocker.patch(
"torch.cuda.get_device_name", return_value="Tesla K80"
)
with pytest.warns(
UserWarning,
match=r"Memory usage for device cuda exceeds threshold",
UserWarning,
match=r"Memory usage for device cuda exceeds threshold",
):
devices = check_device(
memory_threshold=0.8) # Set memory threshold to 80%
memory_threshold=0.8
) # Set memory threshold to 80%
assert len(devices) == 1
assert str(devices[0]) == "cuda"
@ -70,21 +80,27 @@ def test_memory_threshold(mocker):
def test_compute_capability_threshold(mocker):
mocker.patch("torch.cuda.is_available", return_value=True)
mocker.patch("torch.cuda.device_count", return_value=1)
mocker.patch("torch.cuda.get_device_capability",
return_value=(3, 0)) # Compute capability 3.0
mocker.patch(
"torch.cuda.get_device_capability", return_value=(3, 0)
) # Compute capability 3.0
mocker.patch(
"torch.cuda.get_device_properties",
return_value=MagicMock(total_memory=1000),
)
mocker.patch("torch.cuda.memory_allocated", return_value=200)
mocker.patch("torch.cuda.memory_reserved", return_value=300)
mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
mocker.patch(
"torch.cuda.get_device_name", return_value="Tesla K80"
)
with pytest.warns(
UserWarning,
match=(r"Compute capability for device cuda is below threshold"),
UserWarning,
match=(
r"Compute capability for device cuda is below threshold"
),
):
devices = check_device(
capability_threshold=3.5) # Set compute capability threshold to 3.5
capability_threshold=3.5
) # Set compute capability threshold to 3.5
assert len(devices) == 1
assert str(devices[0]) == "cuda"

@ -14,7 +14,8 @@ def test_basic_message():
with mock.patch.object(Console, "print") as mock_print:
display_markdown_message("This is a test")
mock_print.assert_called_once_with(
Markdown("This is a test", style="cyan"))
Markdown("This is a test", style="cyan")
)
def test_empty_message():
@ -30,7 +31,8 @@ def test_colors(color):
with mock.patch.object(Console, "print") as mock_print:
display_markdown_message("This is a test", color)
mock_print.assert_called_once_with(
Markdown("This is a test", style=color))
Markdown("This is a test", style=color)
)
def test_dash_line():

@ -22,8 +22,12 @@ def markdown_content_without_code():
"""
def test_extract_code_from_markdown_with_code(markdown_content_with_code,):
extracted_code = extract_code_from_markdown(markdown_content_with_code)
def test_extract_code_from_markdown_with_code(
markdown_content_with_code,
):
extracted_code = extract_code_from_markdown(
markdown_content_with_code
)
assert "def my_func():" in extracted_code
assert 'print("This is my function.")' in extracted_code
assert "class MyClass:" in extracted_code
@ -31,8 +35,11 @@ def test_extract_code_from_markdown_with_code(markdown_content_with_code,):
def test_extract_code_from_markdown_without_code(
markdown_content_without_code,):
extracted_code = extract_code_from_markdown(markdown_content_without_code)
markdown_content_without_code,
):
extracted_code = extract_code_from_markdown(
markdown_content_without_code
)
assert extracted_code == ""

@ -8,8 +8,12 @@ from swarms.utils import find_image_path
def test_find_image_path_no_images():
assert (find_image_path("This is a test string without any image paths.")
is None)
assert (
find_image_path(
"This is a test string without any image paths."
)
is None
)
def test_find_image_path_one_image():
@ -19,7 +23,9 @@ def test_find_image_path_one_image():
def test_find_image_path_multiple_images():
text = "This string has two image paths: img1.png, and img2.jpg."
assert (find_image_path(text) == "img2.jpg") # Assuming both images exist
assert (
find_image_path(text) == "img2.jpg"
) # Assuming both images exist
def test_find_image_path_wrong_input():

@ -4,11 +4,14 @@ from swarms.utils import limit_tokens_from_string
def test_limit_tokens_from_string():
sentence = ("This is a test sentence. It is used for testing the number"
" of tokens.")
sentence = (
"This is a test sentence. It is used for testing the number"
" of tokens."
)
limited = limit_tokens_from_string(sentence, limit=5)
assert (len(limited.split())
<= 5), "The output string has more than 5 tokens."
assert (
len(limited.split()) <= 5
), "The output string has more than 5 tokens."
def test_limit_zero_tokens():
@ -18,21 +21,26 @@ def test_limit_zero_tokens():
def test_negative_token_limit():
sentence = ("This test will raise an exception when limit is negative.")
sentence = (
"This test will raise an exception when limit is negative."
)
with pytest.raises(Exception):
limit_tokens_from_string(sentence, limit=-1)
@pytest.mark.parametrize("sentence, model",
[("Some sentence", "unavailable-model")])
@pytest.mark.parametrize(
"sentence, model", [("Some sentence", "unavailable-model")]
)
def test_unknown_model(sentence, model):
with pytest.raises(Exception):
limit_tokens_from_string(sentence, model=model)
def test_string_token_limit_exceeded():
sentence = ("This is a long sentence with more than twenty tokens which"
" is used for testing. It checks whether the function"
" correctly limits the tokens to a specified amount.")
sentence = (
"This is a long sentence with more than twenty tokens which"
" is used for testing. It checks whether the function"
" correctly limits the tokens to a specified amount."
)
limited = limit_tokens_from_string(sentence, limit=20)
assert len(limited.split()) <= 20, "The token limit is exceeded."

@ -6,7 +6,6 @@ from swarms.utils import load_model_torch
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 2)
@ -26,7 +25,9 @@ def test_load_model_torch_success(tmp_path):
model_loaded = load_model_torch(model_path, model=DummyModel())
# Check if loaded model has the same architecture
assert isinstance(model_loaded, DummyModel), "Loaded model type mismatch."
assert isinstance(
model_loaded, DummyModel
), "Loaded model type mismatch."
# Test case 2: Test if function raises FileNotFoundError for non-existent file
@ -65,12 +66,13 @@ def test_load_model_torch_device_handling(tmp_path):
# Define a device other than default and load the model to the specified device
device = torch.device("cpu")
model_loaded = load_model_torch(model_path,
model=DummyModel(),
device=device)
model_loaded = load_model_torch(
model_path, model=DummyModel(), device=device
)
assert (model_loaded.fc.weight.device == device
), "Model not loaded to specified device."
assert (
model_loaded.fc.weight.device == device
), "Model not loaded to specified device."
# Test case 6: Testing for correct handling of '*args' and '**kwargs'
@ -80,14 +82,15 @@ def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path):
torch.save(model.state_dict(), model_path)
def mock_torch_load(*args, **kwargs):
assert ("pickle_module"
in kwargs), "Keyword arguments not passed to 'torch.load'."
assert (
"pickle_module" in kwargs
), "Keyword arguments not passed to 'torch.load'."
# Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly
monkeypatch.setattr(torch, "load", mock_torch_load)
load_model_torch(model_path,
model=DummyModel(),
pickle_module="dummy_module")
load_model_torch(
model_path, model=DummyModel(), pickle_module="dummy_module"
)
# Test case 7: Test for model loading on CPU if no GPU is available
@ -100,8 +103,9 @@ def test_load_model_torch_cpu(tmp_path):
return False
# Monkeypatch to simulate no GPU available
pytest.MonkeyPatch.setattr(torch.cuda, "is_available",
mock_torch_cuda_is_available)
pytest.MonkeyPatch.setattr(
torch.cuda, "is_available", mock_torch_cuda_is_available
)
model_loaded = load_model_torch(model_path, model=DummyModel())
# Ensure model is loaded on CPU

@ -42,13 +42,15 @@ def test_load_model_torch_model_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value={"key": "value"})
load_model_torch("model_path", model=mock_model)
mock_model.load_state_dict.assert_called_once_with({"key": "value"},
strict=True)
mock_model.load_state_dict.assert_called_once_with(
{"key": "value"}, strict=True
)
def test_load_model_torch_model_specified_strict_false(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value={"key": "value"})
load_model_torch("model_path", model=mock_model, strict=False)
mock_model.load_state_dict.assert_called_once_with({"key": "value"},
strict=False)
mock_model.load_state_dict.assert_called_once_with(
{"key": "value"}, strict=False
)

@ -18,7 +18,6 @@ def func2_with_exception(x):
def test_same_results_no_exception(caplog):
@math_eval(func1_no_exception, func2_no_exception)
def test_func(x):
return x
@ -29,7 +28,6 @@ def test_same_results_no_exception(caplog):
def test_func1_exception(caplog):
@math_eval(func1_with_exception, func2_no_exception)
def test_func(x):
return x

@ -10,7 +10,6 @@ from swarms.utils import metrics_decorator
# Basic successful test
def test_metrics_decorator_success():
@metrics_decorator
def decorated_func():
time.sleep(0.1)
@ -31,8 +30,8 @@ def test_metrics_decorator_success():
],
)
def test_metrics_decorator_with_various_wait_times_and_return_vals(
wait_time, return_val):
wait_time, return_val
):
@metrics_decorator
def decorated_func():
time.sleep(wait_time)
@ -56,17 +55,19 @@ def test_metrics_decorator_with_mocked_time(mocker):
return ["tok_1", "tok_2"]
metrics = decorated_func()
assert (metrics == """
assert (
metrics
== """
Time to First Token: 5
Generation Latency: 20
Throughput: 0.1
""")
"""
)
mocked_time.assert_any_call()
# Test to ensure that exceptions in the decorated function are propagated
def test_metrics_decorator_raises_exception():
@metrics_decorator
def decorated_func():
raise ValueError("Oops!")
@ -77,7 +78,6 @@ def test_metrics_decorator_raises_exception():
# Test to ensure proper handling when decorated function returns non-list value
def test_metrics_decorator_with_non_list_return_val():
@metrics_decorator
def decorated_func():
return "Hello, world!"

@ -29,8 +29,8 @@ def test_passing_non_pdf_file(tmpdir):
file = tmpdir.join("temp.txt")
file.write("This is a test")
with pytest.raises(
Exception,
match=r"An error occurred while reading the PDF file",
Exception,
match=r"An error occurred while reading the PDF file",
):
pdf_to_text(str(file))

@ -9,13 +9,16 @@ from swarms.utils import prep_torch_inference
def test_prep_torch_inference():
model_path = "model_path"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
model_mock = Mock()
model_mock.eval = Mock()
# Mocking the load_model_torch function to return our mock model.
with unittest.mock.patch("swarms.utils.load_model_torch",
return_value=model_mock) as _:
with unittest.mock.patch(
"swarms.utils.load_model_torch", return_value=model_mock
) as _:
model = prep_torch_inference(model_path, device)
# Check if model was properly loaded and eval function was called

@ -3,7 +3,8 @@ from unittest.mock import MagicMock
import torch
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,)
prep_torch_inference,
)
def test_prep_torch_inference_no_model_path():

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save