Merge pull request #1214 from hughiwnl/test-custom_agent

Created tests for custom_agent
pull/1237/head^2
Kye Gomez 2 days ago committed by GitHub
commit e111fbac82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -6,7 +6,6 @@ from swarms.structs.custom_agent import CustomAgent, AgentResponse
try: try:
import pytest_asyncio import pytest_asyncio
ASYNC_AVAILABLE = True ASYNC_AVAILABLE = True
except ImportError: except ImportError:
ASYNC_AVAILABLE = False ASYNC_AVAILABLE = False
@ -41,10 +40,7 @@ def test_custom_agent_initialization():
timeout=30.0, timeout=30.0,
verify_ssl=True, verify_ssl=True,
) )
assert ( assert custom_agent_instance.base_url == "https://api.example.com"
custom_agent_instance.base_url
== "https://api.example.com"
)
assert custom_agent_instance.endpoint == "v1/endpoint" assert custom_agent_instance.endpoint == "v1/endpoint"
assert custom_agent_instance.timeout == 30.0 assert custom_agent_instance.timeout == 30.0
assert custom_agent_instance.verify_ssl is True assert custom_agent_instance.verify_ssl is True
@ -55,9 +51,7 @@ def test_custom_agent_initialization():
raise raise
def test_custom_agent_initialization_with_default_headers( def test_custom_agent_initialization_with_default_headers(sample_custom_agent):
sample_custom_agent,
):
try: try:
custom_agent_no_headers = CustomAgent( custom_agent_no_headers = CustomAgent(
name="TestAgent", name="TestAgent",
@ -65,9 +59,7 @@ def test_custom_agent_initialization_with_default_headers(
base_url="https://api.test.com", base_url="https://api.test.com",
endpoint="test", endpoint="test",
) )
assert ( assert "Content-Type" in custom_agent_no_headers.default_headers
"Content-Type" in custom_agent_no_headers.default_headers
)
assert ( assert (
custom_agent_no_headers.default_headers["Content-Type"] custom_agent_no_headers.default_headers["Content-Type"]
== "application/json" == "application/json"
@ -86,10 +78,7 @@ def test_custom_agent_url_normalization():
base_url="https://api.test.com/", base_url="https://api.test.com/",
endpoint="/v1/test", endpoint="/v1/test",
) )
assert ( assert custom_agent_with_slashes.base_url == "https://api.test.com"
custom_agent_with_slashes.base_url
== "https://api.test.com"
)
assert custom_agent_with_slashes.endpoint == "v1/test" assert custom_agent_with_slashes.endpoint == "v1/test"
logger.debug("URL normalization works correctly") logger.debug("URL normalization works correctly")
except Exception as e: except Exception as e:
@ -101,22 +90,14 @@ def test_prepare_headers(sample_custom_agent):
try: try:
prepared_headers = sample_custom_agent._prepare_headers() prepared_headers = sample_custom_agent._prepare_headers()
assert "Authorization" in prepared_headers assert "Authorization" in prepared_headers
assert ( assert prepared_headers["Authorization"] == "Bearer test-token"
prepared_headers["Authorization"] == "Bearer test-token"
)
additional_headers = {"X-Custom-Header": "custom-value"} additional_headers = {"X-Custom-Header": "custom-value"}
prepared_headers_with_additional = ( prepared_headers_with_additional = (
sample_custom_agent._prepare_headers(additional_headers) sample_custom_agent._prepare_headers(additional_headers)
) )
assert ( assert prepared_headers_with_additional["X-Custom-Header"] == "custom-value"
prepared_headers_with_additional["X-Custom-Header"] assert prepared_headers_with_additional["Authorization"] == "Bearer test-token"
== "custom-value"
)
assert (
prepared_headers_with_additional["Authorization"]
== "Bearer test-token"
)
logger.debug("Header preparation works correctly") logger.debug("Header preparation works correctly")
except Exception as e: except Exception as e:
logger.error(f"Failed to test prepare_headers: {e}") logger.error(f"Failed to test prepare_headers: {e}")
@ -126,9 +107,7 @@ def test_prepare_headers(sample_custom_agent):
def test_prepare_payload_dict(sample_custom_agent): def test_prepare_payload_dict(sample_custom_agent):
try: try:
payload_dict = {"key": "value", "number": 123} payload_dict = {"key": "value", "number": 123}
prepared_payload = sample_custom_agent._prepare_payload( prepared_payload = sample_custom_agent._prepare_payload(payload_dict)
payload_dict
)
assert isinstance(prepared_payload, str) assert isinstance(prepared_payload, str)
parsed = json.loads(prepared_payload) parsed = json.loads(prepared_payload)
assert parsed["key"] == "value" assert parsed["key"] == "value"
@ -142,30 +121,22 @@ def test_prepare_payload_dict(sample_custom_agent):
def test_prepare_payload_string(sample_custom_agent): def test_prepare_payload_string(sample_custom_agent):
try: try:
payload_string = '{"test": "value"}' payload_string = '{"test": "value"}'
prepared_payload = sample_custom_agent._prepare_payload( prepared_payload = sample_custom_agent._prepare_payload(payload_string)
payload_string
)
assert prepared_payload == payload_string assert prepared_payload == payload_string
logger.debug("String payload prepared correctly") logger.debug("String payload prepared correctly")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to test prepare_payload with string: {e}")
f"Failed to test prepare_payload with string: {e}"
)
raise raise
def test_prepare_payload_bytes(sample_custom_agent): def test_prepare_payload_bytes(sample_custom_agent):
try: try:
payload_bytes = b'{"test": "value"}' payload_bytes = b'{"test": "value"}'
prepared_payload = sample_custom_agent._prepare_payload( prepared_payload = sample_custom_agent._prepare_payload(payload_bytes)
payload_bytes
)
assert prepared_payload == payload_bytes assert prepared_payload == payload_bytes
logger.debug("Bytes payload prepared correctly") logger.debug("Bytes payload prepared correctly")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to test prepare_payload with bytes: {e}")
f"Failed to test prepare_payload with bytes: {e}"
)
raise raise
@ -177,9 +148,7 @@ def test_parse_response_success(sample_custom_agent):
mock_response.headers = {"content-type": "application/json"} mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = {"message": "success"} mock_response.json.return_value = {"message": "success"}
parsed_response = sample_custom_agent._parse_response( parsed_response = sample_custom_agent._parse_response(mock_response)
mock_response
)
assert isinstance(parsed_response, AgentResponse) assert isinstance(parsed_response, AgentResponse)
assert parsed_response.status_code == 200 assert parsed_response.status_code == 200
assert parsed_response.success is True assert parsed_response.success is True
@ -198,9 +167,7 @@ def test_parse_response_error(sample_custom_agent):
mock_response.text = "Not Found" mock_response.text = "Not Found"
mock_response.headers = {"content-type": "text/plain"} mock_response.headers = {"content-type": "text/plain"}
parsed_response = sample_custom_agent._parse_response( parsed_response = sample_custom_agent._parse_response(mock_response)
mock_response
)
assert isinstance(parsed_response, AgentResponse) assert isinstance(parsed_response, AgentResponse)
assert parsed_response.status_code == 404 assert parsed_response.status_code == 404
assert parsed_response.success is False assert parsed_response.success is False
@ -222,15 +189,11 @@ def test_extract_content_openai_format(sample_custom_agent):
} }
] ]
} }
extracted_content = sample_custom_agent._extract_content( extracted_content = sample_custom_agent._extract_content(openai_response)
openai_response
)
assert extracted_content == "This is the response content" assert extracted_content == "This is the response content"
logger.debug("OpenAI format content extracted correctly") logger.debug("OpenAI format content extracted correctly")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to test extract_content OpenAI format: {e}")
f"Failed to test extract_content OpenAI format: {e}"
)
raise raise
@ -239,33 +202,25 @@ def test_extract_content_anthropic_format(sample_custom_agent):
anthropic_response = { anthropic_response = {
"content": [ "content": [
{"text": "First part "}, {"text": "First part "},
{"text": "second part"}, {"text": "second part"}
] ]
} }
extracted_content = sample_custom_agent._extract_content( extracted_content = sample_custom_agent._extract_content(anthropic_response)
anthropic_response
)
assert extracted_content == "First part second part" assert extracted_content == "First part second part"
logger.debug("Anthropic format content extracted correctly") logger.debug("Anthropic format content extracted correctly")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to test extract_content Anthropic format: {e}")
f"Failed to test extract_content Anthropic format: {e}"
)
raise raise
def test_extract_content_generic_format(sample_custom_agent): def test_extract_content_generic_format(sample_custom_agent):
try: try:
generic_response = {"text": "Generic response text"} generic_response = {"text": "Generic response text"}
extracted_content = sample_custom_agent._extract_content( extracted_content = sample_custom_agent._extract_content(generic_response)
generic_response
)
assert extracted_content == "Generic response text" assert extracted_content == "Generic response text"
logger.debug("Generic format content extracted correctly") logger.debug("Generic format content extracted correctly")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to test extract_content generic format: {e}")
f"Failed to test extract_content generic format: {e}"
)
raise raise
@ -274,18 +229,14 @@ def test_run_success(mock_client_class, sample_custom_agent):
try: try:
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.text = ( mock_response.text = '{"choices": [{"message": {"content": "Success"}}]}'
'{"choices": [{"message": {"content": "Success"}}]}'
)
mock_response.json.return_value = { mock_response.json.return_value = {
"choices": [{"message": {"content": "Success"}}] "choices": [{"message": {"content": "Success"}}]
} }
mock_response.headers = {"content-type": "application/json"} mock_response.headers = {"content-type": "application/json"}
mock_client_instance = Mock() mock_client_instance = Mock()
mock_client_instance.__enter__ = Mock( mock_client_instance.__enter__ = Mock(return_value=mock_client_instance)
return_value=mock_client_instance
)
mock_client_instance.__exit__ = Mock(return_value=None) mock_client_instance.__exit__ = Mock(return_value=None)
mock_client_instance.post.return_value = mock_response mock_client_instance.post.return_value = mock_response
mock_client_class.return_value = mock_client_instance mock_client_class.return_value = mock_client_instance
@ -308,9 +259,7 @@ def test_run_error_response(mock_client_class, sample_custom_agent):
mock_response.text = "Internal Server Error" mock_response.text = "Internal Server Error"
mock_client_instance = Mock() mock_client_instance = Mock()
mock_client_instance.__enter__ = Mock( mock_client_instance.__enter__ = Mock(return_value=mock_client_instance)
return_value=mock_client_instance
)
mock_client_instance.__exit__ = Mock(return_value=None) mock_client_instance.__exit__ = Mock(return_value=None)
mock_client_instance.post.return_value = mock_response mock_client_instance.post.return_value = mock_response
mock_client_class.return_value = mock_client_instance mock_client_class.return_value = mock_client_instance
@ -331,13 +280,9 @@ def test_run_request_error(mock_client_class, sample_custom_agent):
import httpx import httpx
mock_client_instance = Mock() mock_client_instance = Mock()
mock_client_instance.__enter__ = Mock( mock_client_instance.__enter__ = Mock(return_value=mock_client_instance)
return_value=mock_client_instance
)
mock_client_instance.__exit__ = Mock(return_value=None) mock_client_instance.__exit__ = Mock(return_value=None)
mock_client_instance.post.side_effect = httpx.RequestError( mock_client_instance.post.side_effect = httpx.RequestError("Connection failed")
"Connection failed"
)
mock_client_class.return_value = mock_client_instance mock_client_class.return_value = mock_client_instance
test_payload = {"message": "test"} test_payload = {"message": "test"}
@ -350,33 +295,23 @@ def test_run_request_error(mock_client_class, sample_custom_agent):
raise raise
@pytest.mark.skipif( @pytest.mark.skipif(not ASYNC_AVAILABLE, reason="pytest-asyncio not installed")
not ASYNC_AVAILABLE, reason="pytest-asyncio not installed"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("swarms.structs.custom_agent.httpx.AsyncClient") @patch("swarms.structs.custom_agent.httpx.AsyncClient")
async def test_run_async_success( async def test_run_async_success(mock_async_client_class, sample_custom_agent):
mock_async_client_class, sample_custom_agent
):
try: try:
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.text = ( mock_response.text = '{"content": [{"text": "Async Success"}]}'
'{"content": [{"text": "Async Success"}]}'
)
mock_response.json.return_value = { mock_response.json.return_value = {
"content": [{"text": "Async Success"}] "content": [{"text": "Async Success"}]
} }
mock_response.headers = {"content-type": "application/json"} mock_response.headers = {"content-type": "application/json"}
mock_client_instance = AsyncMock() mock_client_instance = AsyncMock()
mock_client_instance.__aenter__ = AsyncMock( mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
return_value=mock_client_instance
)
mock_client_instance.__aexit__ = AsyncMock(return_value=None) mock_client_instance.__aexit__ = AsyncMock(return_value=None)
mock_client_instance.post = AsyncMock( mock_client_instance.post = AsyncMock(return_value=mock_response)
return_value=mock_response
)
mock_async_client_class.return_value = mock_client_instance mock_async_client_class.return_value = mock_client_instance
test_payload = {"message": "test"} test_payload = {"message": "test"}
@ -389,27 +324,19 @@ async def test_run_async_success(
raise raise
@pytest.mark.skipif( @pytest.mark.skipif(not ASYNC_AVAILABLE, reason="pytest-asyncio not installed")
not ASYNC_AVAILABLE, reason="pytest-asyncio not installed"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("swarms.structs.custom_agent.httpx.AsyncClient") @patch("swarms.structs.custom_agent.httpx.AsyncClient")
async def test_run_async_error_response( async def test_run_async_error_response(mock_async_client_class, sample_custom_agent):
mock_async_client_class, sample_custom_agent
):
try: try:
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 400 mock_response.status_code = 400
mock_response.text = "Bad Request" mock_response.text = "Bad Request"
mock_client_instance = AsyncMock() mock_client_instance = AsyncMock()
mock_client_instance.__aenter__ = AsyncMock( mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
return_value=mock_client_instance
)
mock_client_instance.__aexit__ = AsyncMock(return_value=None) mock_client_instance.__aexit__ = AsyncMock(return_value=None)
mock_client_instance.post = AsyncMock( mock_client_instance.post = AsyncMock(return_value=mock_response)
return_value=mock_response
)
mock_async_client_class.return_value = mock_client_instance mock_async_client_class.return_value = mock_client_instance
test_payload = {"message": "test"} test_payload = {"message": "test"}
@ -439,4 +366,4 @@ def test_agent_response_dataclass():
logger.debug("AgentResponse dataclass created correctly") logger.debug("AgentResponse dataclass created correctly")
except Exception as e: except Exception as e:
logger.error(f"Failed to test AgentResponse dataclass: {e}") logger.error(f"Failed to test AgentResponse dataclass: {e}")
raise raise
Loading…
Cancel
Save