[NEW DOCS] [Updated swarms api documentation] [IMPR] [Improved swarms output types and automatic port]

parent: e67f4161f3
commit: 41c7004dcd
@ -0,0 +1,85 @@
from swarms.structs.conversation import (
    Conversation,
    get_conversation_dir,
)
import os
import shutil


def cleanup_test_conversations():
    """Clean up test conversation files after running the example."""
    conv_dir = get_conversation_dir()
    if os.path.exists(conv_dir):
        shutil.rmtree(conv_dir)
        print(
            f"\nCleaned up test conversations directory: {conv_dir}"
        )


def main():
    # Example 1: In-memory only conversation (no saving)
    print("\nExample 1: In-memory conversation (no saving)")
    conv_memory = Conversation(
        name="memory_only_chat",
        save_enabled=False,  # Don't save to disk
        autosave=False,
    )
    conv_memory.add("user", "This conversation won't be saved!")
    conv_memory.display_conversation()

    # Example 2: Conversation with autosaving
    print("\nExample 2: Conversation with autosaving")
    conversation_dir = get_conversation_dir()
    print(f"Conversations will be stored in: {conversation_dir}")

    conv_autosave = Conversation(
        name="autosave_chat",
        conversations_dir=conversation_dir,
        save_enabled=True,  # Enable saving
        autosave=True,  # Enable autosaving
    )
    print(f"Created new conversation with ID: {conv_autosave.id}")
    print(
        f"This conversation is saved at: {conv_autosave.save_filepath}"
    )

    # Add some messages (each will be autosaved)
    conv_autosave.add("user", "Hello! How are you?")
    conv_autosave.add(
        "assistant",
        "I'm doing well, thank you! How can I help you today?",
    )

    # Example 3: Load from specific file
    print("\nExample 3: Load from specific file")
    custom_file = os.path.join(conversation_dir, "custom_chat.json")

    # Create a conversation and save it to a custom file
    conv_custom = Conversation(
        name="custom_chat",
        save_filepath=custom_file,
        save_enabled=True,
    )
    conv_custom.add("user", "This is a custom saved conversation")
    conv_custom.add(
        "assistant", "I'll be saved in a custom location!"
    )
    conv_custom.save_as_json()

    # Now load it specifically
    loaded_conv = Conversation.load_conversation(
        name="custom_chat", load_filepath=custom_file
    )
    print("Loaded custom conversation:")
    loaded_conv.display_conversation()

    # List all saved conversations
    print("\nAll saved conversations:")
    conversations = Conversation.list_conversations(conversation_dir)
    for conv_info in conversations:
        print(
            f"- {conv_info['name']} (ID: {conv_info['id']}, Created: {conv_info['created_at']})"
        )


main()
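A note on the entry point: as committed, this example runs main() at import time and never calls cleanup_test_conversations(). If you want the script to tidy up after itself, a guarded entry point along these lines (a sketch, not part of the commit) would work:

if __name__ == "__main__":
    try:
        main()
    finally:
        # Remove the conversation files the example wrote to disk
        # (the committed example defines this helper but never calls it).
        cleanup_test_conversations()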
@ -0,0 +1,9 @@
from swarms.structs.agent import Agent

agent = Agent(
    agent_name="test",
    agent_description="test",
    system_prompt="test",
)

print(agent.list_output_types())
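list_output_types() presumably enumerates the values accepted by an agent's output_type parameter; to use one, pass it at construction time. A minimal sketch (assuming output_type accepts the HistoryOutputType literals defined later in this commit; "json" is one of them):

agent_json = Agent(
    agent_name="test",
    agent_description="test",
    system_prompt="test",
    output_type="json",  # any value reported by list_output_types()
)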
@ -0,0 +1,38 @@
AGGREGATOR_SYSTEM_PROMPT = """You are a highly skilled Aggregator Agent responsible for analyzing, synthesizing, and summarizing conversations between multiple AI agents. Your primary goal is to distill complex multi-agent interactions into clear, actionable insights.

Key Responsibilities:
1. Conversation Analysis:
   - Identify the main topics and themes discussed
   - Track the progression of ideas and problem-solving approaches
   - Recognize key decisions and turning points in the conversation
   - Note any conflicts, agreements, or important conclusions reached

2. Agent Contribution Assessment:
   - Evaluate each agent's unique contributions to the discussion
   - Highlight complementary perspectives and insights
   - Identify any knowledge gaps or areas requiring further exploration
   - Recognize patterns in agent interactions and collaborative dynamics

3. Summary Generation Guidelines:
   - Begin with a high-level overview of the conversation's purpose and outcome
   - Structure the summary in a logical, hierarchical manner
   - Prioritize critical information while maintaining context
   - Include specific examples or quotes when they significantly impact understanding
   - Maintain objectivity while synthesizing different viewpoints
   - Highlight actionable insights and next steps if applicable

4. Quality Standards:
   - Ensure accuracy in representing each agent's contributions
   - Maintain clarity and conciseness without oversimplifying
   - Use consistent terminology throughout the summary
   - Preserve important technical details and domain-specific language
   - Flag any uncertainties or areas needing clarification

5. Output Format:
   - Present information in a structured, easy-to-read format
   - Use bullet points or sections for better readability when appropriate
   - Include a brief conclusion or recommendation section if relevant
   - Maintain professional and neutral tone throughout

Remember: Your role is crucial in making complex multi-agent discussions accessible and actionable. Focus on extracting value from the conversation while maintaining the integrity of each agent's contributions.
"""
@ -0,0 +1,71 @@
from datetime import datetime
from typing import Any, List, Optional

from pydantic import BaseModel, Field


class Usage(BaseModel):
    prompt_tokens: Optional[int] = Field(
        default=None,
        description="Number of tokens used in the prompt",
    )
    completion_tokens: Optional[int] = Field(
        default=None,
        description="Number of tokens used in the completion",
    )
    total_tokens: Optional[int] = Field(
        default=None, description="Total number of tokens used"
    )


class ModelConfig(BaseModel):
    model_name: Optional[str] = Field(
        default=None,
        description="Name of the model used for generation",
    )
    temperature: Optional[float] = Field(
        default=None,
        description="Temperature setting used for generation",
    )
    top_p: Optional[float] = Field(
        default=None, description="Top-p setting used for generation"
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens to generate",
    )
    frequency_penalty: Optional[float] = Field(
        default=None,
        description="Frequency penalty used for generation",
    )
    presence_penalty: Optional[float] = Field(
        default=None,
        description="Presence penalty used for generation",
    )


class AgentCompletionResponse(BaseModel):
    id: Optional[str] = Field(
        default=None, description="Unique identifier for the response"
    )
    agent_name: Optional[str] = Field(
        default=None,
        description="Name of the agent that generated the response",
    )
    agent_description: Optional[str] = Field(
        default=None, description="Description of the agent"
    )
    outputs: Optional[List[Any]] = Field(
        default=None,
        description="List of outputs generated by the agent",
    )
    usage: Optional[Usage] = Field(
        default=None, description="Token usage statistics"
    )
    model_config: Optional[ModelConfig] = Field(
        default=None, description="Model configuration"
    )
    timestamp: Optional[str] = Field(
        default_factory=lambda: datetime.now().isoformat(),
        description="Timestamp of when the response was generated",
    )
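For a quick sense of how these schemas serialize, here is a minimal sketch using the Usage model (Pydantic v2 API, with hypothetical token counts). One caveat worth flagging: AgentCompletionResponse declares a field literally named model_config, a name Pydantic v2 reserves for model configuration, so that field may need renaming (e.g. to a hypothetical generation_config) for the class to build under v2:

# Hypothetical values, for illustration only.
usage = Usage(prompt_tokens=120, completion_tokens=80, total_tokens=200)
print(usage.model_dump_json())
# -> {"prompt_tokens":120,"completion_tokens":80,"total_tokens":200}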
@ -0,0 +1,9 @@
from typing import Optional
from pydantic import BaseModel, Field


class ConversationSchema(BaseModel):
    time_enabled: Optional[bool] = Field(default=False)
    message_id_on: Optional[bool] = Field(default=True)
    autosave: Optional[bool] = Field(default=False)
    count_tokens: Optional[bool] = Field(default=False)
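A quick sketch of how this config object behaves, using nothing beyond the fields above (Pydantic v2's model_dump):

config = ConversationSchema(autosave=True)
print(config.model_dump())
# {'time_enabled': False, 'message_id_on': True, 'autosave': True, 'count_tokens': False}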
@ -0,0 +1,84 @@
from swarms.structs.agent import Agent
from typing import List, Callable
from swarms.structs.conversation import Conversation
from swarms.structs.multi_agent_exec import run_agents_concurrently
from swarms.utils.history_output_formatter import (
    history_output_formatter,
    HistoryOutputType,
)

from swarms.prompts.agent_conversation_aggregator import (
    AGGREGATOR_SYSTEM_PROMPT,
)


def aggregator_agent_task_prompt(
    task: str, workers: List[Agent], conversation: Conversation
):
    return f"""
    Please analyze and summarize the following multi-agent conversation, following your guidelines for comprehensive synthesis:

    Conversation Context:
    Original Task: {task}
    Number of Participating Agents: {len(workers)}

    Conversation Content:
    {conversation.get_str()}

    Please provide a 3,000 word comprehensive summary report of the conversation.
    """


def aggregate(
    workers: List[Callable],
    task: str = None,
    type: HistoryOutputType = "all",
    aggregator_model_name: str = "anthropic/claude-3-sonnet-20240229",
):
    """
    Run the worker agents concurrently on a task, then have an
    aggregator agent synthesize their outputs into a single summary.
    """

    if task is None:
        raise ValueError("Task is required in the aggregator block")

    if workers is None:
        raise ValueError(
            "Workers is required in the aggregator block"
        )

    if not isinstance(workers, list):
        raise ValueError("Workers must be a list of Callable")

    if not all(isinstance(worker, Callable) for worker in workers):
        raise ValueError("Workers must be a list of Callable")

    conversation = Conversation()

    aggregator_agent = Agent(
        agent_name="Aggregator",
        agent_description="Expert agent specializing in analyzing and synthesizing multi-agent conversations",
        system_prompt=AGGREGATOR_SYSTEM_PROMPT,
        max_loops=1,
        model_name=aggregator_model_name,
        output_type="final",
        max_tokens=4000,
    )

    results = run_agents_concurrently(agents=workers, task=task)

    # Zip the results with the agents
    for result, agent in zip(results, workers):
        conversation.add(content=result, role=agent.agent_name)

    final_result = aggregator_agent.run(
        task=aggregator_agent_task_prompt(task, workers, conversation)
    )

    conversation.add(
        content=final_result, role=aggregator_agent.agent_name
    )

    return history_output_formatter(
        conversation=conversation, type=type
    )
@ -1,6 +0,0 @@
from swarms.utils.history_output_formatter import (
    HistoryOutputType as OutputType,
)

# Use the OutputType for type annotations
output_type: OutputType  # OutputType now includes 'xml'
@ -0,0 +1,23 @@
from typing import Literal

HistoryOutputType = Literal[
    "list",
    "dict",
    "dictionary",
    "string",
    "str",
    "final",
    "last",
    "json",
    "all",
    "yaml",
    "xml",
    # "dict-final",
    "dict-all-except-first",
    "str-all-except-first",
    "basemodel",
]

OutputType = HistoryOutputType

output_type: HistoryOutputType  # OutputType now includes 'xml'
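To see what these literals control, a minimal sketch (assuming the literal names map to what they suggest, e.g. "all" returning the full history and "final" only the last message; the exact return shapes live in swarms.utils.history_output_formatter):

from swarms.structs.conversation import Conversation
from swarms.utils.history_output_formatter import (
    history_output_formatter,
)

conversation = Conversation()
conversation.add("user", "Hello")
conversation.add("assistant", "Hi there!")

# Same conversation, two different output shapes.
full = history_output_formatter(conversation=conversation, type="all")
last = history_output_formatter(conversation=conversation, type="final")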
@ -1,127 +0,0 @@
import platform
from typing import Any


from clusterops import (
    execute_on_gpu,
    execute_on_multiple_gpus,
    list_available_gpus,
    execute_with_all_cpu_cores,
    execute_on_cpu,
)
from swarms.utils.loguru_logger import initialize_logger

logger = initialize_logger(log_folder="clusterops_wrapper")


def exec_callable_with_clusterops(
    device: str = "cpu",
    device_id: int = 1,
    all_cores: bool = True,
    all_gpus: bool = False,
    func: callable = None,
    enable_logging: bool = True,
    *args,
    **kwargs,
) -> Any:
    """
    Executes a given function on a specified device, either CPU or GPU.

    This method attempts to execute a given function on a specified device, either CPU or GPU. It logs the device selection and the number of cores or GPU ID used. If the device is set to CPU, it can use all available cores or a specific core specified by `device_id`. If the device is set to GPU, it uses the GPU specified by `device_id`.

    Args:
        device (str, optional): The device to use for execution. Defaults to "cpu".
        device_id (int, optional): The ID of the GPU to use if device is set to "gpu". Defaults to 1.
        all_cores (bool, optional): If True, uses all available CPU cores. Defaults to True.
        all_gpus (bool, optional): If True, uses all available GPUs. Defaults to False.
        func (callable): The function to execute.
        enable_logging (bool, optional): If True, enables logging. Defaults to True.
        *args: Additional positional arguments to be passed to the execution method.
        **kwargs: Additional keyword arguments to be passed to the execution method.

    Returns:
        Any: The result of the execution.

    Raises:
        ValueError: If an invalid device is specified.
        Exception: If any other error occurs during execution.
    """
    if func is None:
        raise ValueError("A callable function must be provided")

    try:
        if enable_logging:
            logger.info(f"Attempting to run on device: {device}")
        device = device.lower()

        # Check if the platform is Windows and do nothing if true
        if platform.system() == "Windows":
            if enable_logging:
                logger.info(
                    "Platform is Windows, not executing on device."
                )
            return None

        if device == "cpu":
            if enable_logging:
                logger.info("Device set to CPU")

            if all_cores:
                if enable_logging:
                    logger.info("Using all CPU cores")
                return execute_with_all_cpu_cores(
                    func, *args, **kwargs
                )

            if device_id is not None:
                if enable_logging:
                    logger.info(
                        f"Using specific CPU core: {device_id}"
                    )
                return execute_on_cpu(
                    device_id, func, *args, **kwargs
                )

        elif device == "gpu":
            if enable_logging:
                logger.info("Device set to GPU")

            if all_gpus:
                if enable_logging:
                    logger.info("Using all available GPUs")
                gpus = [int(gpu) for gpu in list_available_gpus()]
                return execute_on_multiple_gpus(
                    gpus, func, *args, **kwargs
                )

            if enable_logging:
                logger.info(f"Using GPU device ID: {device_id}")
            return execute_on_gpu(device_id, func, *args, **kwargs)

        else:
            raise ValueError(
                f"Invalid device specified: {device}. Supported devices are 'cpu' and 'gpu'."
            )

    except ValueError as e:
        if enable_logging:
            logger.error(
                f"Invalid device or configuration specified: {e}"
            )
        raise
    except Exception as e:
        if enable_logging:
            logger.error(f"An error occurred during execution: {e}")
        raise


# def test_clusterops(x):
#     return x + 1

# example = exec_callable_with_clusterops(
#     device="cpu",
#     all_cores=True,
#     func = test_clusterops,
# )

# print(example)
@ -0,0 +1,567 @@
import os
import json
import time
import datetime
import yaml
from swarms.structs.conversation import (
    Conversation,
    generate_conversation_id,
)


def run_all_tests():
    """Run all tests for the Conversation class"""
    test_results = []

    def run_test(test_func):
        try:
            test_func()
            test_results.append(f"✅ {test_func.__name__} passed")
        except Exception as e:
            test_results.append(
                f"❌ {test_func.__name__} failed: {str(e)}"
            )

    def test_basic_initialization():
        """Test basic initialization of Conversation"""
        conv = Conversation()
        assert conv.id is not None
        assert conv.conversation_history is not None
        assert isinstance(conv.conversation_history, list)

        # Test with custom ID
        custom_id = generate_conversation_id()
        conv_with_id = Conversation(id=custom_id)
        assert conv_with_id.id == custom_id

        # Test with custom name
        conv_with_name = Conversation(name="Test Conversation")
        assert conv_with_name.name == "Test Conversation"

    def test_initialization_with_settings():
        """Test initialization with various settings"""
        conv = Conversation(
            system_prompt="Test system prompt",
            time_enabled=True,
            autosave=True,
            token_count=True,
            provider="in-memory",
            context_length=4096,
            rules="Test rules",
            custom_rules_prompt="Custom rules",
            user="TestUser:",
            save_as_yaml=True,
            save_as_json_bool=True,
        )

        # Test all settings
        assert conv.system_prompt == "Test system prompt"
        assert conv.time_enabled is True
        assert conv.autosave is True
        assert conv.token_count is True
        assert conv.provider == "in-memory"
        assert conv.context_length == 4096
        assert conv.rules == "Test rules"
        assert conv.custom_rules_prompt == "Custom rules"
        assert conv.user == "TestUser:"
        assert conv.save_as_yaml is True
        assert conv.save_as_json_bool is True

    def test_message_manipulation():
        """Test adding, deleting, and updating messages"""
        conv = Conversation()

        # Test adding messages with different content types
        conv.add("user", "Hello")  # String content
        conv.add("assistant", {"response": "Hi"})  # Dict content
        conv.add("system", ["Hello", "Hi"])  # List content

        assert len(conv.conversation_history) == 3
        assert isinstance(
            conv.conversation_history[1]["content"], dict
        )
        assert isinstance(
            conv.conversation_history[2]["content"], list
        )

        # Test adding multiple messages
        conv.add_multiple(
            ["user", "assistant", "system"],
            ["Hi", "Hello there", "System message"],
        )
        assert len(conv.conversation_history) == 6

        # Test updating message with different content type
        conv.update(0, "user", {"updated": "content"})
        assert isinstance(
            conv.conversation_history[0]["content"], dict
        )

        # Test deleting multiple messages
        conv.delete(0)
        conv.delete(0)
        assert len(conv.conversation_history) == 4

    def test_message_retrieval():
        """Test message retrieval methods"""
        conv = Conversation()

        # Add messages in specific order for testing
        conv.add("user", "Test message")
        conv.add("assistant", "Test response")
        conv.add("system", "System message")

        # Test query - note: messages might have system prompt prepended
        message = conv.query(0)
        assert "Test message" in message["content"]

        # Test search with multiple results
        results = conv.search("Test")
        assert (
            len(results) >= 2
        )  # At least two messages should contain "Test"
        assert any(
            "Test message" in str(msg["content"]) for msg in results
        )
        assert any(
            "Test response" in str(msg["content"]) for msg in results
        )

        # Test get_last_message_as_string
        last_message = conv.get_last_message_as_string()
        assert "System message" in last_message

        # Test return_messages_as_list
        messages_list = conv.return_messages_as_list()
        assert (
            len(messages_list) >= 3
        )  # At least our 3 added messages
        assert any("Test message" in msg for msg in messages_list)

        # Test return_messages_as_dictionary
        messages_dict = conv.return_messages_as_dictionary()
        assert (
            len(messages_dict) >= 3
        )  # At least our 3 added messages
        assert all(isinstance(m, dict) for m in messages_dict)
        assert all(
            {"role", "content"} <= set(m.keys())
            for m in messages_dict
        )

        # Test get_final_message and content
        assert "System message" in conv.get_final_message()
        assert "System message" in conv.get_final_message_content()

        # Test return_all_except_first
        remaining_messages = conv.return_all_except_first()
        assert (
            len(remaining_messages) >= 2
        )  # At least 2 messages after removing first

        # Test return_all_except_first_string
        remaining_string = conv.return_all_except_first_string()
        assert isinstance(remaining_string, str)

    def test_saving_loading():
        """Test saving and loading conversation"""
        # Test with save_enabled
        conv = Conversation(
            save_enabled=True,
            conversations_dir="./test_conversations",
        )
        conv.add("user", "Test save message")

        # Test save_as_json
        test_file = os.path.join(
            "./test_conversations", "test_conversation.json"
        )
        conv.save_as_json(test_file)
        assert os.path.exists(test_file)

        # Test load_from_json
        new_conv = Conversation()
        new_conv.load_from_json(test_file)
        assert len(new_conv.conversation_history) == 1
        assert (
            new_conv.conversation_history[0]["content"]
            == "Test save message"
        )

        # Test class method load_conversation
        loaded_conv = Conversation.load_conversation(
            name=conv.id, conversations_dir="./test_conversations"
        )
        assert loaded_conv.id == conv.id

        # Cleanup
        os.remove(test_file)
        os.rmdir("./test_conversations")

    def test_output_formats():
        """Test different output formats"""
        conv = Conversation()
        conv.add("user", "Test message")
        conv.add("assistant", {"response": "Test"})

        # Test JSON output
        json_output = conv.to_json()
        assert isinstance(json_output, str)
        parsed_json = json.loads(json_output)
        assert len(parsed_json) == 2

        # Test dict output
        dict_output = conv.to_dict()
        assert isinstance(dict_output, list)
        assert len(dict_output) == 2

        # Test YAML output
        yaml_output = conv.to_yaml()
        assert isinstance(yaml_output, str)
        parsed_yaml = yaml.safe_load(yaml_output)
        assert len(parsed_yaml) == 2

        # Test return_json
        json_str = conv.return_json()
        assert isinstance(json_str, str)
        assert len(json.loads(json_str)) == 2

    def test_memory_management():
        """Test memory management functions"""
        conv = Conversation()

        # Test clear
        conv.add("user", "Test message")
        conv.clear()
        assert len(conv.conversation_history) == 0

        # Test clear_memory
        conv.add("user", "Test message")
        conv.clear_memory()
        assert len(conv.conversation_history) == 0

        # Test batch operations
        messages = [
            {"role": "user", "content": "Message 1"},
            {"role": "assistant", "content": "Response 1"},
        ]
        conv.batch_add(messages)
        assert len(conv.conversation_history) == 2

        # Test truncate_memory_with_tokenizer
        if conv.tokenizer:  # Only if tokenizer is available
            conv.truncate_memory_with_tokenizer()
            assert len(conv.conversation_history) > 0

    def test_conversation_metadata():
        """Test conversation metadata and listing"""
        test_dir = "./test_conversations_metadata"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Create a conversation with metadata
            conv = Conversation(
                name="Test Conv",
                system_prompt="System",
                rules="Rules",
                custom_rules_prompt="Custom",
                conversations_dir=test_dir,
                save_enabled=True,
                autosave=True,
            )

            # Add a message to trigger save
            conv.add("user", "Test message")

            # Give a small delay for autosave
            time.sleep(0.1)

            # List conversations and verify
            conversations = Conversation.list_conversations(test_dir)
            assert len(conversations) >= 1
            found_conv = next(
                (
                    c
                    for c in conversations
                    if c["name"] == "Test Conv"
                ),
                None,
            )
            assert found_conv is not None
            assert found_conv["id"] == conv.id

        finally:
            # Cleanup
            import shutil

            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_time_enabled_messages():
        """Test time-enabled messages"""
        conv = Conversation(time_enabled=True)
        conv.add("user", "Time test")

        # Verify timestamp in message
        message = conv.conversation_history[0]
        assert "timestamp" in message
        assert isinstance(message["timestamp"], str)

        # Verify time in content when time_enabled is True
        assert "Time:" in message["content"]

    def test_provider_specific():
        """Test provider-specific functionality"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test")
        assert len(conv_memory.conversation_history) == 1

        # Test mem0 provider if available
        try:
            conv_mem0 = Conversation(provider="mem0")
            conv_mem0.add("user", "Test")
            # Add appropriate assertions based on mem0 behavior
        except Exception:
            pass  # Skip if mem0 is not available

    def test_tool_output():
        """Test tool output handling"""
        conv = Conversation()
        tool_output = {
            "tool_name": "test_tool",
            "output": "test result",
        }
        conv.add_tool_output_to_agent("tool", tool_output)

        assert len(conv.conversation_history) == 1
        assert conv.conversation_history[0]["role"] == "tool"
        assert conv.conversation_history[0]["content"] == tool_output

    def test_autosave_functionality():
        """Test autosave functionality and related features"""
        test_dir = "./test_conversations_autosave"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Test with autosave and save_enabled True
            conv = Conversation(
                autosave=True,
                save_enabled=True,
                conversations_dir=test_dir,
                name="autosave_test",
            )

            # Add a message and verify it was auto-saved
            conv.add("user", "Test autosave message")
            save_path = os.path.join(test_dir, f"{conv.id}.json")

            # Give a small delay for autosave to complete
            time.sleep(0.1)

            assert os.path.exists(
                save_path
            ), f"Save file not found at {save_path}"

            # Load the saved conversation and verify content
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )
            found_message = False
            for msg in loaded_conv.conversation_history:
                if "Test autosave message" in str(msg["content"]):
                    found_message = True
                    break
            assert (
                found_message
            ), "Message not found in loaded conversation"

            # Clean up first conversation files
            if os.path.exists(save_path):
                os.remove(save_path)

            # Test with save_enabled=False
            conv_no_save = Conversation(
                autosave=False,  # Changed to False to prevent autosave
                save_enabled=False,
                conversations_dir=test_dir,
            )
            conv_no_save.add("user", "This shouldn't be saved")
            save_path_no_save = os.path.join(
                test_dir, f"{conv_no_save.id}.json"
            )
            time.sleep(0.1)  # Give time for potential save
            assert not os.path.exists(
                save_path_no_save
            ), "File should not exist when save_enabled is False"

        finally:
            # Cleanup
            import shutil

            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_advanced_message_handling():
        """Test advanced message handling features"""
        conv = Conversation()

        # Test adding messages with metadata
        metadata = {"timestamp": "2024-01-01", "session_id": "123"}
        conv.add("user", "Test with metadata", metadata=metadata)

        # Test batch operations with different content types
        messages = [
            {"role": "user", "content": "Message 1"},
            {
                "role": "assistant",
                "content": {"response": "Complex response"},
            },
            {"role": "system", "content": ["Multiple", "Items"]},
        ]
        conv.batch_add(messages)
        assert (
            len(conv.conversation_history) == 4
        )  # Including the first message

        # Test message format consistency
        for msg in conv.conversation_history:
            assert "role" in msg
            assert "content" in msg
            if "timestamp" in msg:
                assert isinstance(msg["timestamp"], str)

    def test_conversation_metadata_handling():
        """Test handling of conversation metadata and attributes"""
        test_dir = "./test_conversations_metadata_handling"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Test initialization with all optional parameters
            conv = Conversation(
                name="Test Conv",
                system_prompt="System Prompt",
                time_enabled=True,
                context_length=2048,
                rules="Test Rules",
                custom_rules_prompt="Custom Rules",
                user="CustomUser:",
                provider="in-memory",
                conversations_dir=test_dir,
                save_enabled=True,
            )

            # Verify all attributes are set correctly
            assert conv.name == "Test Conv"
            assert conv.system_prompt == "System Prompt"
            assert conv.time_enabled is True
            assert conv.context_length == 2048
            assert conv.rules == "Test Rules"
            assert conv.custom_rules_prompt == "Custom Rules"
            assert conv.user == "CustomUser:"
            assert conv.provider == "in-memory"

            # Test saving and loading preserves metadata
            conv.save_as_json()

            # Load using load_conversation
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )

            # Verify metadata was preserved
            assert loaded_conv.name == "Test Conv"
            assert loaded_conv.system_prompt == "System Prompt"
            assert loaded_conv.rules == "Test Rules"

        finally:
            # Cleanup
            import shutil

            shutil.rmtree(test_dir)

    def test_time_enabled_features():
        """Test time-enabled message features"""
        conv = Conversation(time_enabled=True)

        # Add message and verify timestamp
        conv.add("user", "Time test message")
        message = conv.conversation_history[0]

        # Verify timestamp format
        assert "timestamp" in message
        try:
            datetime.datetime.fromisoformat(message["timestamp"])
        except ValueError:
            assert False, "Invalid timestamp format"

        # Verify time in content
        assert "Time:" in message["content"]
        assert (
            datetime.datetime.now().strftime("%Y-%m-%d")
            in message["content"]
        )

    def test_provider_specific_features():
        """Test provider-specific features and behaviors"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test in-memory")
        assert len(conv_memory.conversation_history) == 1
        assert (
            "Test in-memory"
            in conv_memory.get_last_message_as_string()
        )

        # Test mem0 provider if available
        try:
            from mem0 import AsyncMemory  # noqa: F401

            # Skip actual mem0 testing since it requires async
            pass
        except ImportError:
            pass

        # Test invalid provider
        invalid_provider = "invalid_provider"
        try:
            Conversation(provider=invalid_provider)
            # If we get here, the provider was accepted when it shouldn't have been
            raise AssertionError(
                f"Should have raised ValueError for provider '{invalid_provider}'"
            )
        except ValueError:
            # This is the expected behavior
            pass

    # Run all tests
    tests = [
        test_basic_initialization,
        test_initialization_with_settings,
        test_message_manipulation,
        test_message_retrieval,
        test_saving_loading,
        test_output_formats,
        test_memory_management,
        test_conversation_metadata,
        test_time_enabled_messages,
        test_provider_specific,
        test_tool_output,
        test_autosave_functionality,
        test_advanced_message_handling,
        test_conversation_metadata_handling,
        test_time_enabled_features,
        test_provider_specific_features,
    ]

    for test in tests:
        run_test(test)

    # Print results
    print("\nTest Results:")
    for result in test_results:
        print(result)


if __name__ == "__main__":
    run_all_tests()
@ -0,0 +1,36 @@
from swarms.structs.agent import Agent
from swarms.structs.ma_blocks import aggregate


agents = [
    Agent(
        agent_name="Sector-Financial-Analyst",
        agent_description="Senior financial analyst at BlackRock.",
        system_prompt="You are a financial analyst tasked with optimizing asset allocations for a $50B portfolio. Provide clear, quantitative recommendations for each sector.",
        max_loops=1,
        model_name="gpt-4o-mini",
        max_tokens=3000,
    ),
    Agent(
        agent_name="Sector-Risk-Analyst",
        agent_description="Expert risk management analyst.",
        system_prompt="You are a risk analyst responsible for advising on risk allocation within a $50B portfolio. Provide detailed insights on risk exposures for each sector.",
        max_loops=1,
        model_name="gpt-4o-mini",
        max_tokens=3000,
    ),
    Agent(
        agent_name="Tech-Sector-Analyst",
        agent_description="Technology sector analyst.",
        system_prompt="You are a tech sector analyst focused on capital and risk allocations. Provide data-backed insights for the tech sector.",
        max_loops=1,
        model_name="gpt-4o-mini",
        max_tokens=3000,
    ),
]

aggregate(
    workers=agents,
    task="What is the best sector to invest in?",
    type="all",
)
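The type argument here accepts any HistoryOutputType literal from swarms.utils.history_output_formatter; passing type="final" instead of type="all" should return only the aggregator's synthesized summary rather than the full multi-agent transcript (inferred from the literal names and the output_type="final" setting on the aggregator agent; check history_output_formatter for the exact return shapes).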