parent
c8f1d82f85
commit
30c140fbaa
@ -1,40 +0,0 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms.structs.custom_agent import CustomAgent
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Example usage with Anthropic API
|
||||
if __name__ == "__main__":
|
||||
# Initialize the agent for Anthropic API
|
||||
anthropic_agent = CustomAgent(
|
||||
base_url="https://api.anthropic.com",
|
||||
endpoint="v1/messages",
|
||||
headers={
|
||||
"x-api-key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic-version": "2023-06-01",
|
||||
},
|
||||
)
|
||||
|
||||
# Example payload for Anthropic API
|
||||
payload = {
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"max_tokens": 1000,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! Can you explain what artaddificial intelligence is?",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Make the request
|
||||
try:
|
||||
response = anthropic_agent.run(payload)
|
||||
print("Anthropic API Response:")
|
||||
print(response)
|
||||
print(type(response))
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
@ -1,344 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Data class to hold agent response information"""
|
||||
|
||||
status_code: int
|
||||
content: str
|
||||
headers: Dict[str, str]
|
||||
json_data: Optional[Dict[str, Any]] = None
|
||||
success: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class CustomAgent:
|
||||
"""
|
||||
A custom HTTP agent class for making POST requests using httpx.
|
||||
|
||||
Features:
|
||||
- Configurable headers and payload
|
||||
- Both sync and async execution
|
||||
- Built-in error handling and logging
|
||||
- Flexible response handling
|
||||
- Name and description
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
base_url: str,
|
||||
endpoint: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 30.0,
|
||||
verify_ssl: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Custom Agent.
|
||||
|
||||
Args:
|
||||
base_url: Base URL for the API endpoint
|
||||
endpoint: API endpoint path
|
||||
headers: Default headers to include in requests
|
||||
timeout: Request timeout in seconds
|
||||
verify_ssl: Whether to verify SSL certificates
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.endpoint = endpoint.lstrip("/")
|
||||
self.default_headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
|
||||
# Default headers
|
||||
if "Content-Type" not in self.default_headers:
|
||||
self.default_headers["Content-Type"] = "application/json"
|
||||
|
||||
logger.info(
|
||||
f"CustomAgent initialized for {self.base_url}/{self.endpoint}"
|
||||
)
|
||||
|
||||
def _prepare_headers(
|
||||
self, additional_headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Merge default headers with additional headers."""
|
||||
headers = self.default_headers.copy()
|
||||
if additional_headers:
|
||||
headers.update(additional_headers)
|
||||
return headers
|
||||
|
||||
def _prepare_payload(
|
||||
self, payload: Union[Dict, str, bytes]
|
||||
) -> Union[str, bytes]:
|
||||
"""Prepare the payload for the request."""
|
||||
if isinstance(payload, dict):
|
||||
return json.dumps(payload)
|
||||
return payload
|
||||
|
||||
def _parse_response(
|
||||
self, response: httpx.Response
|
||||
) -> AgentResponse:
|
||||
"""Parse httpx response into AgentResponse object."""
|
||||
try:
|
||||
# Try to parse JSON if possible
|
||||
json_data = None
|
||||
if response.headers.get("content-type", "").startswith(
|
||||
"application/json"
|
||||
):
|
||||
try:
|
||||
json_data = response.json()
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return AgentResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.text,
|
||||
headers=dict(response.headers),
|
||||
json_data=json_data,
|
||||
success=200 <= response.status_code < 300,
|
||||
error_message=(
|
||||
None
|
||||
if 200 <= response.status_code < 300
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing response: {e}")
|
||||
return AgentResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.text,
|
||||
headers=dict(response.headers),
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
def _extract_content(self, response_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract message content from API response, supporting multiple formats.
|
||||
|
||||
Args:
|
||||
response_data: Parsed JSON response from API
|
||||
|
||||
Returns:
|
||||
str: Extracted message content
|
||||
"""
|
||||
try:
|
||||
# OpenAI format
|
||||
if (
|
||||
"choices" in response_data
|
||||
and response_data["choices"]
|
||||
):
|
||||
choice = response_data["choices"][0]
|
||||
if (
|
||||
"message" in choice
|
||||
and "content" in choice["message"]
|
||||
):
|
||||
return choice["message"]["content"]
|
||||
elif "text" in choice:
|
||||
return choice["text"]
|
||||
|
||||
# Anthropic format
|
||||
elif (
|
||||
"content" in response_data
|
||||
and response_data["content"]
|
||||
):
|
||||
if isinstance(response_data["content"], list):
|
||||
# Extract text from content blocks
|
||||
text_parts = []
|
||||
for content_block in response_data["content"]:
|
||||
if (
|
||||
isinstance(content_block, dict)
|
||||
and "text" in content_block
|
||||
):
|
||||
text_parts.append(content_block["text"])
|
||||
elif isinstance(content_block, str):
|
||||
text_parts.append(content_block)
|
||||
return "".join(text_parts)
|
||||
elif isinstance(response_data["content"], str):
|
||||
return response_data["content"]
|
||||
|
||||
# Generic fallback - look for common content fields
|
||||
elif "text" in response_data:
|
||||
return response_data["text"]
|
||||
elif "message" in response_data:
|
||||
return response_data["message"]
|
||||
elif "response" in response_data:
|
||||
return response_data["response"]
|
||||
|
||||
# If no known format, return the entire response as JSON string
|
||||
logger.warning(
|
||||
"Unknown response format, returning full response"
|
||||
)
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content: {e}")
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
def run(
|
||||
self,
|
||||
payload: Union[Dict[str, Any], str, bytes],
|
||||
additional_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Execute a synchronous POST request.
|
||||
|
||||
Args:
|
||||
payload: Request body/payload
|
||||
additional_headers: Additional headers for this request
|
||||
**kwargs: Additional httpx client options
|
||||
|
||||
Returns:
|
||||
str: Extracted message content from response
|
||||
"""
|
||||
url = f"{self.base_url}/{self.endpoint}"
|
||||
request_headers = self._prepare_headers(additional_headers)
|
||||
request_payload = self._prepare_payload(payload)
|
||||
|
||||
logger.info(f"Making POST request to: {url}")
|
||||
|
||||
try:
|
||||
with httpx.Client(
|
||||
timeout=self.timeout, verify=self.verify_ssl, **kwargs
|
||||
) as client:
|
||||
response = client.post(
|
||||
url,
|
||||
content=request_payload,
|
||||
headers=request_headers,
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
logger.info(
|
||||
f"Request successful: {response.status_code}"
|
||||
)
|
||||
try:
|
||||
response_data = response.json()
|
||||
return self._extract_content(response_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Response is not JSON, returning raw text"
|
||||
)
|
||||
return response.text
|
||||
else:
|
||||
logger.warning(
|
||||
f"Request failed: {response.status_code}"
|
||||
)
|
||||
return f"Error: HTTP {response.status_code} - {response.text}"
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
return f"Request error: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
return f"Unexpected error: {str(e)}"
|
||||
|
||||
async def run_async(
|
||||
self,
|
||||
payload: Union[Dict[str, Any], str, bytes],
|
||||
additional_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Execute an asynchronous POST request.
|
||||
|
||||
Args:
|
||||
payload: Request body/payload
|
||||
additional_headers: Additional headers for this request
|
||||
**kwargs: Additional httpx client options
|
||||
|
||||
Returns:
|
||||
str: Extracted message content from response
|
||||
"""
|
||||
url = f"{self.base_url}/{self.endpoint}"
|
||||
request_headers = self._prepare_headers(additional_headers)
|
||||
request_payload = self._prepare_payload(payload)
|
||||
|
||||
logger.info(f"Making async POST request to: {url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=self.timeout, verify=self.verify_ssl, **kwargs
|
||||
) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
content=request_payload,
|
||||
headers=request_headers,
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
logger.info(
|
||||
f"Async request successful: {response.status_code}"
|
||||
)
|
||||
try:
|
||||
response_data = response.json()
|
||||
return self._extract_content(response_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Async response is not JSON, returning raw text"
|
||||
)
|
||||
return response.text
|
||||
else:
|
||||
logger.warning(
|
||||
f"Async request failed: {response.status_code}"
|
||||
)
|
||||
return f"Error: HTTP {response.status_code} - {response.text}"
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Async request error: {e}")
|
||||
return f"Request error: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected async error: {e}")
|
||||
return f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
# # Example usage with Anthropic API
|
||||
# if __name__ == "__main__":
|
||||
# # Initialize the agent for Anthropic API
|
||||
# anthropic_agent = CustomAgent(
|
||||
# base_url="https://api.anthropic.com",
|
||||
# endpoint="v1/messages",
|
||||
# headers={
|
||||
# "x-api-key": "your-anthropic-api-key-here",
|
||||
# "anthropic-version": "2023-06-01"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # Example payload for Anthropic API
|
||||
# payload = {
|
||||
# "model": "claude-3-sonnet-20240229",
|
||||
# "max_tokens": 1000,
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "Hello! Can you explain what artificial intelligence is?"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
|
||||
# # Make the request
|
||||
# try:
|
||||
# response = anthropic_agent.run(payload)
|
||||
# print("Anthropic API Response:")
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# print(f"Error: {e}")
|
||||
|
||||
# # Example with async usage
|
||||
# # import asyncio
|
||||
# #
|
||||
# # async def async_example():
|
||||
# # response = await anthropic_agent.run_async(payload)
|
||||
# # print("Async Anthropic API Response:")
|
||||
# # print(response)
|
||||
# #
|
||||
# # Uncomment to run async example
|
||||
# # asyncio.run(async_example())
|
||||
@ -1,442 +0,0 @@
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from loguru import logger
|
||||
from swarms.structs.custom_agent import CustomAgent, AgentResponse
|
||||
|
||||
try:
|
||||
import pytest_asyncio
|
||||
|
||||
ASYNC_AVAILABLE = True
|
||||
except ImportError:
|
||||
ASYNC_AVAILABLE = False
|
||||
pytest_asyncio = None
|
||||
|
||||
|
||||
def create_test_custom_agent():
|
||||
return CustomAgent(
|
||||
name="TestAgent",
|
||||
description="Test agent for unit testing",
|
||||
base_url="https://api.test.com",
|
||||
endpoint="v1/test",
|
||||
headers={"Authorization": "Bearer test-token"},
|
||||
timeout=10.0,
|
||||
verify_ssl=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_custom_agent():
|
||||
return create_test_custom_agent()
|
||||
|
||||
|
||||
def test_custom_agent_initialization():
|
||||
try:
|
||||
custom_agent_instance = CustomAgent(
|
||||
name="TestAgent",
|
||||
description="Test description",
|
||||
base_url="https://api.example.com",
|
||||
endpoint="v1/endpoint",
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30.0,
|
||||
verify_ssl=True,
|
||||
)
|
||||
assert (
|
||||
custom_agent_instance.base_url
|
||||
== "https://api.example.com"
|
||||
)
|
||||
assert custom_agent_instance.endpoint == "v1/endpoint"
|
||||
assert custom_agent_instance.timeout == 30.0
|
||||
assert custom_agent_instance.verify_ssl is True
|
||||
assert "Content-Type" in custom_agent_instance.default_headers
|
||||
logger.info("CustomAgent initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CustomAgent: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_custom_agent_initialization_with_default_headers(
|
||||
sample_custom_agent,
|
||||
):
|
||||
try:
|
||||
custom_agent_no_headers = CustomAgent(
|
||||
name="TestAgent",
|
||||
description="Test",
|
||||
base_url="https://api.test.com",
|
||||
endpoint="test",
|
||||
)
|
||||
assert (
|
||||
"Content-Type" in custom_agent_no_headers.default_headers
|
||||
)
|
||||
assert (
|
||||
custom_agent_no_headers.default_headers["Content-Type"]
|
||||
== "application/json"
|
||||
)
|
||||
logger.debug("Default Content-Type header added correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test default headers: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_custom_agent_url_normalization():
|
||||
try:
|
||||
custom_agent_with_slashes = CustomAgent(
|
||||
name="TestAgent",
|
||||
description="Test",
|
||||
base_url="https://api.test.com/",
|
||||
endpoint="/v1/test",
|
||||
)
|
||||
assert (
|
||||
custom_agent_with_slashes.base_url
|
||||
== "https://api.test.com"
|
||||
)
|
||||
assert custom_agent_with_slashes.endpoint == "v1/test"
|
||||
logger.debug("URL normalization works correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test URL normalization: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_prepare_headers(sample_custom_agent):
|
||||
try:
|
||||
prepared_headers = sample_custom_agent._prepare_headers()
|
||||
assert "Authorization" in prepared_headers
|
||||
assert (
|
||||
prepared_headers["Authorization"] == "Bearer test-token"
|
||||
)
|
||||
|
||||
additional_headers = {"X-Custom-Header": "custom-value"}
|
||||
prepared_headers_with_additional = (
|
||||
sample_custom_agent._prepare_headers(additional_headers)
|
||||
)
|
||||
assert (
|
||||
prepared_headers_with_additional["X-Custom-Header"]
|
||||
== "custom-value"
|
||||
)
|
||||
assert (
|
||||
prepared_headers_with_additional["Authorization"]
|
||||
== "Bearer test-token"
|
||||
)
|
||||
logger.debug("Header preparation works correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test prepare_headers: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_prepare_payload_dict(sample_custom_agent):
|
||||
try:
|
||||
payload_dict = {"key": "value", "number": 123}
|
||||
prepared_payload = sample_custom_agent._prepare_payload(
|
||||
payload_dict
|
||||
)
|
||||
assert isinstance(prepared_payload, str)
|
||||
parsed = json.loads(prepared_payload)
|
||||
assert parsed["key"] == "value"
|
||||
assert parsed["number"] == 123
|
||||
logger.debug("Dictionary payload prepared correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test prepare_payload with dict: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_prepare_payload_string(sample_custom_agent):
|
||||
try:
|
||||
payload_string = '{"test": "value"}'
|
||||
prepared_payload = sample_custom_agent._prepare_payload(
|
||||
payload_string
|
||||
)
|
||||
assert prepared_payload == payload_string
|
||||
logger.debug("String payload prepared correctly")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to test prepare_payload with string: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def test_prepare_payload_bytes(sample_custom_agent):
|
||||
try:
|
||||
payload_bytes = b'{"test": "value"}'
|
||||
prepared_payload = sample_custom_agent._prepare_payload(
|
||||
payload_bytes
|
||||
)
|
||||
assert prepared_payload == payload_bytes
|
||||
logger.debug("Bytes payload prepared correctly")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to test prepare_payload with bytes: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def test_parse_response_success(sample_custom_agent):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = '{"message": "success"}'
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.json.return_value = {"message": "success"}
|
||||
|
||||
parsed_response = sample_custom_agent._parse_response(
|
||||
mock_response
|
||||
)
|
||||
assert isinstance(parsed_response, AgentResponse)
|
||||
assert parsed_response.status_code == 200
|
||||
assert parsed_response.success is True
|
||||
assert parsed_response.json_data == {"message": "success"}
|
||||
assert parsed_response.error_message is None
|
||||
logger.debug("Successful response parsed correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test parse_response success: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_parse_response_error(sample_custom_agent):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = "Not Found"
|
||||
mock_response.headers = {"content-type": "text/plain"}
|
||||
|
||||
parsed_response = sample_custom_agent._parse_response(
|
||||
mock_response
|
||||
)
|
||||
assert isinstance(parsed_response, AgentResponse)
|
||||
assert parsed_response.status_code == 404
|
||||
assert parsed_response.success is False
|
||||
assert parsed_response.error_message == "HTTP 404"
|
||||
logger.debug("Error response parsed correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test parse_response error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_extract_content_openai_format(sample_custom_agent):
|
||||
try:
|
||||
openai_response = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "This is the response content"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
extracted_content = sample_custom_agent._extract_content(
|
||||
openai_response
|
||||
)
|
||||
assert extracted_content == "This is the response content"
|
||||
logger.debug("OpenAI format content extracted correctly")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to test extract_content OpenAI format: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def test_extract_content_anthropic_format(sample_custom_agent):
|
||||
try:
|
||||
anthropic_response = {
|
||||
"content": [
|
||||
{"text": "First part "},
|
||||
{"text": "second part"},
|
||||
]
|
||||
}
|
||||
extracted_content = sample_custom_agent._extract_content(
|
||||
anthropic_response
|
||||
)
|
||||
assert extracted_content == "First part second part"
|
||||
logger.debug("Anthropic format content extracted correctly")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to test extract_content Anthropic format: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def test_extract_content_generic_format(sample_custom_agent):
|
||||
try:
|
||||
generic_response = {"text": "Generic response text"}
|
||||
extracted_content = sample_custom_agent._extract_content(
|
||||
generic_response
|
||||
)
|
||||
assert extracted_content == "Generic response text"
|
||||
logger.debug("Generic format content extracted correctly")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to test extract_content generic format: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@patch("swarms.structs.custom_agent.httpx.Client")
|
||||
def test_run_success(mock_client_class, sample_custom_agent):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = (
|
||||
'{"choices": [{"message": {"content": "Success"}}]}'
|
||||
)
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Success"}}]
|
||||
}
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.__enter__ = Mock(
|
||||
return_value=mock_client_instance
|
||||
)
|
||||
mock_client_instance.__exit__ = Mock(return_value=None)
|
||||
mock_client_instance.post.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
test_payload = {"message": "test"}
|
||||
result = sample_custom_agent.run(test_payload)
|
||||
|
||||
assert result == "Success"
|
||||
logger.info("Run method executed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test run success: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@patch("swarms.structs.custom_agent.httpx.Client")
|
||||
def test_run_error_response(mock_client_class, sample_custom_agent):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.__enter__ = Mock(
|
||||
return_value=mock_client_instance
|
||||
)
|
||||
mock_client_instance.__exit__ = Mock(return_value=None)
|
||||
mock_client_instance.post.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
test_payload = {"message": "test"}
|
||||
result = sample_custom_agent.run(test_payload)
|
||||
|
||||
assert "Error: HTTP 500" in result
|
||||
logger.debug("Error response handled correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test run error response: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@patch("swarms.structs.custom_agent.httpx.Client")
|
||||
def test_run_request_error(mock_client_class, sample_custom_agent):
|
||||
try:
|
||||
import httpx
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.__enter__ = Mock(
|
||||
return_value=mock_client_instance
|
||||
)
|
||||
mock_client_instance.__exit__ = Mock(return_value=None)
|
||||
mock_client_instance.post.side_effect = httpx.RequestError(
|
||||
"Connection failed"
|
||||
)
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
test_payload = {"message": "test"}
|
||||
result = sample_custom_agent.run(test_payload)
|
||||
|
||||
assert "Request error" in result
|
||||
logger.debug("Request error handled correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test run request error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not ASYNC_AVAILABLE, reason="pytest-asyncio not installed"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("swarms.structs.custom_agent.httpx.AsyncClient")
|
||||
async def test_run_async_success(
|
||||
mock_async_client_class, sample_custom_agent
|
||||
):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = (
|
||||
'{"content": [{"text": "Async Success"}]}'
|
||||
)
|
||||
mock_response.json.return_value = {
|
||||
"content": [{"text": "Async Success"}]
|
||||
}
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(
|
||||
return_value=mock_client_instance
|
||||
)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client_instance.post = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_async_client_class.return_value = mock_client_instance
|
||||
|
||||
test_payload = {"message": "test"}
|
||||
result = await sample_custom_agent.run_async(test_payload)
|
||||
|
||||
assert result == "Async Success"
|
||||
logger.info("Run_async method executed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test run_async success: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not ASYNC_AVAILABLE, reason="pytest-asyncio not installed"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("swarms.structs.custom_agent.httpx.AsyncClient")
|
||||
async def test_run_async_error_response(
|
||||
mock_async_client_class, sample_custom_agent
|
||||
):
|
||||
try:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(
|
||||
return_value=mock_client_instance
|
||||
)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client_instance.post = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_async_client_class.return_value = mock_client_instance
|
||||
|
||||
test_payload = {"message": "test"}
|
||||
result = await sample_custom_agent.run_async(test_payload)
|
||||
|
||||
assert "Error: HTTP 400" in result
|
||||
logger.debug("Async error response handled correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test run_async error response: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_agent_response_dataclass():
|
||||
try:
|
||||
agent_response_instance = AgentResponse(
|
||||
status_code=200,
|
||||
content="Success",
|
||||
headers={"content-type": "application/json"},
|
||||
json_data={"key": "value"},
|
||||
success=True,
|
||||
error_message=None,
|
||||
)
|
||||
assert agent_response_instance.status_code == 200
|
||||
assert agent_response_instance.content == "Success"
|
||||
assert agent_response_instance.success is True
|
||||
assert agent_response_instance.error_message is None
|
||||
logger.debug("AgentResponse dataclass created correctly")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test AgentResponse dataclass: {e}")
|
||||
raise
|
||||
Loading…
Reference in new issue