diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 0c786b66..2acc2759 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -364,6 +364,7 @@ nav: - Agents with Callable Tools: "swarms/examples/agent_with_tools.md" - Agent with Structured Outputs: "swarms/examples/agent_structured_outputs.md" - Agent With MCP Integration: "swarms/examples/agent_with_mcp.md" + - Message Transforms for Context Management: "swarms/structs/transforms.md" - Vision: - Agents with Vision: "swarms/examples/vision_processing.md" - Agent with Multiple Images: "swarms/examples/multiple_images.md" diff --git a/docs/swarms/structs/transforms.md b/docs/swarms/structs/transforms.md new file mode 100644 index 00000000..ae6fe418 --- /dev/null +++ b/docs/swarms/structs/transforms.md @@ -0,0 +1,346 @@ +# Message Transforms: Context Management for Large Conversations + +The Message Transforms system provides intelligent context size management for AI conversations, automatically handling token limits and message count constraints while preserving the most important contextual information. + +## Overview + +Message transforms enable agents to handle long conversations that exceed model context windows by intelligently compressing the conversation history. The system uses a "middle-out" compression strategy that preserves system messages, recent messages, and the beginning of conversations while compressing or removing less critical middle content. + +## Key Features + +- **Automatic Context Management**: Automatically compresses conversations when they approach token or message limits +- **Middle-Out Compression**: Preserves important context (system messages, recent messages, conversation start) while compressing the middle +- **Model-Aware**: Knows context limits for popular models (GPT-4, Claude, etc.) 
+- **Flexible Configuration**: Highly customizable compression strategies +- **Detailed Logging**: Provides compression statistics and ratios +- **Zero-Configuration Option**: Can work with sensible defaults + +## Core Components + +### TransformConfig + +The configuration class that controls transform behavior: + +```python +@dataclass +class TransformConfig: + enabled: bool = False # Enable/disable transforms + method: str = "middle-out" # Compression method + max_tokens: Optional[int] = None # Token limit override + max_messages: Optional[int] = None # Message limit override + model_name: str = "gpt-4" # Target model for limit detection + preserve_system_messages: bool = True # Always keep system messages + preserve_recent_messages: int = 2 # Number of recent messages to preserve +``` + +### TransformResult + +Contains the results of message transformation: + +```python +@dataclass +class TransformResult: + messages: List[Dict[str, Any]] # Transformed message list + original_token_count: int # Original token count + compressed_token_count: int # New token count after compression + original_message_count: int # Original message count + compressed_message_count: int # New message count after compression + compression_ratio: float # Compression ratio (0.0-1.0) + was_compressed: bool # Whether compression occurred +``` + +### MessageTransforms + +The main transformation engine: + +```python +class MessageTransforms: + def __init__(self, config: TransformConfig): + """Initialize with configuration.""" + + def transform_messages( + self, + messages: List[Dict[str, Any]], + target_model: Optional[str] = None, + ) -> TransformResult: + """Transform messages according to configuration.""" +``` + +## Usage Examples + +### Basic Agent with Transforms + +```python +from swarms import Agent +from swarms.structs.transforms import TransformConfig + +# Initialize agent with transforms enabled +agent = Agent( + agent_name="Trading-Agent", + agent_description="Financial analysis agent", + model_name="gpt-4o", + max_loops=1, + transforms=TransformConfig( + enabled=True, + method="middle-out", + model_name="gpt-4o", + preserve_system_messages=True, + preserve_recent_messages=3, + ), +) + +result = agent.run("Analyze the current market trends...") +``` + +### Dictionary Configuration + +```python +# Alternative dictionary-based configuration +agent = Agent( + agent_name="Analysis-Agent", + model_name="claude-3-sonnet", + transforms={ + "enabled": True, + "method": "middle-out", + "model_name": "claude-3-sonnet", + "preserve_system_messages": True, + "preserve_recent_messages": 5, + "max_tokens": 100000, # Custom token limit + }, +) +``` + +### Manual Transform Application + +```python +from swarms.structs.transforms import MessageTransforms, TransformConfig + +# Create transform instance +config = TransformConfig( + enabled=True, + model_name="gpt-4", + preserve_recent_messages=2 +) +transforms = MessageTransforms(config) + +# Apply to message list +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + # ... many messages ... + {"role": "user", "content": "What's the weather?"}, +] + +result = transforms.transform_messages(messages) +print(f"Compressed {result.original_token_count} -> {result.compressed_token_count} tokens") +``` + +## Compression Strategy + +### Middle-Out Algorithm + +The middle-out compression strategy works as follows: + +1. **Preserve System Messages**: Always keep system messages at the beginning +2. 
**Preserve Recent Messages**: Keep the most recent N messages
3. **Compress Middle**: Apply compression to messages in the middle of the conversation
4. **Maintain Context Flow**: Ensure the compressed conversation still makes contextual sense

A standalone sketch of this trimming strategy appears under "Worked Sketches" at the end of this page.

### Token vs Message Limits

The system handles two types of limits:

- **Token Limits**: Based on the model's context window (e.g., GPT-4: 8K, Claude: 200K)
- **Message Limits**: Some models limit total message count (e.g., Claude: 1000 messages)

### Smart Model Detection

Built-in knowledge of popular models:

```python
# Built-in context limits (in tokens) include:
{
    "gpt-4": 8192,
    "gpt-4-turbo": 128000,
    "gpt-4o": 128000,
    "claude-3-opus": 200000,
    "claude-3-sonnet": 200000,
    # ... and many more
}
```

## Advanced Configuration

### Custom Token Limits

```python
# Override default model limits
config = TransformConfig(
    enabled=True,
    model_name="gpt-4",
    max_tokens=50000,  # Custom limit instead of default 8192
    max_messages=500,  # Custom message limit
)
```

### System Message Preservation

```python
# Fine-tune what gets preserved
config = TransformConfig(
    enabled=True,
    preserve_system_messages=True,  # Keep all system messages
    preserve_recent_messages=5,  # Keep last 5 messages
)
```

## Helper Functions

### Quick Setup

```python
from swarms.structs.transforms import create_default_transforms

# Create with sensible defaults
transforms = create_default_transforms(
    enabled=True,
    model_name="claude-3-sonnet"
)
```

### Direct Application

```python
from swarms.structs.transforms import apply_transforms_to_messages

# Apply transforms to messages directly
result = apply_transforms_to_messages(
    messages=my_messages,
    model_name="gpt-4o"
)
```

## Integration with Agent Memory

Transforms work seamlessly with conversation memory systems:

```python
# Transforms integrate with conversation history
def handle_transforms(
    transforms: MessageTransforms,
    short_memory: Conversation,
    model_name: str = "gpt-4o",
) -> str:
    """Apply transforms to conversation memory and return a prompt string."""
    messages = short_memory.return_messages_as_dictionary()
    result = transforms.transform_messages(messages, model_name)

    if result.was_compressed:
        logger.info(f"Compressed conversation: {result.compression_ratio:.2f} ratio")

    # Convert the (possibly compressed) messages back into a prompt string
    return "\n\n".join(
        f"{message['role']}: {message['content']}"
        for message in result.messages
    )
```

## Best Practices

### When to Use Transforms

- **Long Conversations**: When conversations exceed model context limits
- **Memory-Intensive Tasks**: Research, analysis, or multi-turn reasoning
- **Production Systems**: Where conversation length is unpredictable
- **Cost Optimization**: Reducing token usage for long conversations

### Configuration Guidelines

- **Start Simple**: Use defaults, then customize based on needs
- **Monitor Compression**: Check logs for compression ratios and effectiveness
- **Preserve Context**: Keep enough recent messages for continuity
- **Test Thoroughly**: Verify compressed conversations maintain quality

### Performance Considerations

- **Token Counting**: Uses efficient tokenization libraries
- **Memory Efficient**: Processes messages in-place when possible
- **Logging Overhead**: Compression stats are logged only when compression occurs
- **Model Compatibility**: Works with any model that has known limits

## Troubleshooting

### Common Issues

**Transforms not activating:**
- Check that `enabled=True` in configuration
- Verify model name matches supported models
- 
Ensure message count/token count exceeds thresholds

**Poor compression quality:**
- Increase `preserve_recent_messages`
- Ensure system messages are preserved
- Check compression ratios in logs

**Unexpected behavior:**
- Review configuration parameters
- Check model-specific limits
- Examine conversation structure

## API Reference

### TransformConfig Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enabled` | `bool` | `False` | Enable/disable transforms |
| `method` | `str` | `"middle-out"` | Compression method |
| `max_tokens` | `Optional[int]` | `None` | Custom token limit |
| `max_messages` | `Optional[int]` | `None` | Custom message limit |
| `model_name` | `str` | `"gpt-4"` | Target model name |
| `preserve_system_messages` | `bool` | `True` | Preserve system messages |
| `preserve_recent_messages` | `int` | `2` | Recent messages to keep |

### TransformResult Fields

| Field | Type | Description |
|-------|------|-------------|
| `messages` | `List[Dict]` | Transformed messages |
| `original_token_count` | `int` | Original token count |
| `compressed_token_count` | `int` | Compressed token count |
| `original_message_count` | `int` | Original message count |
| `compressed_message_count` | `int` | Compressed message count |
| `compression_ratio` | `float` | Compression ratio |
| `was_compressed` | `bool` | Whether compression occurred |

## Examples in Action

### Real-World Use Case: Research Agent

```python
# Research agent that handles long document analysis
research_agent = Agent(
    agent_name="Research-Agent",
    model_name="claude-3-opus",
    transforms=TransformConfig(
        enabled=True,
        model_name="claude-3-opus",
        preserve_recent_messages=5,  # Keep recent context for follow-ups
        max_tokens=150000,  # Leave room for responses
    ),
)

# Agent can now handle very long research conversations
# without hitting context limits
```

### Use Case: Customer Support Bot

```python
# Support bot maintaining conversation history
support_agent = Agent(
    agent_name="Support-Agent",
    model_name="gpt-4o",
    transforms=TransformConfig(
        enabled=True,
        preserve_system_messages=True,
        preserve_recent_messages=10,  # Keep recent conversation
        max_messages=100,  # Reasonable conversation length
    ),
)
```

This transforms system lets agents carry on conversations that would otherwise exceed a model's context window, while preserving contextual coherence and keeping token costs predictable.
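
## Worked Sketches

### Middle-Out Trimming in Isolation

To make the algorithm described above concrete, here is a self-contained sketch of middle-out trimming by message count. It is a simplified illustration, not the library implementation: the real `MessageTransforms._compress_message_count` also folds in messages from both ends of the compressed middle span and additionally handles token budgets.

```python
from typing import Any, Dict, List


def middle_out_sketch(
    messages: List[Dict[str, Any]],
    max_messages: int,
    preserve_recent: int = 2,
) -> List[Dict[str, Any]]:
    """Keep system messages, the start of the conversation, and the most recent turns."""
    if len(messages) <= max_messages:
        return messages

    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]

    slots = max_messages - len(system)
    recent = rest[-preserve_recent:]  # newest turns always survive
    early = rest[: max(slots - len(recent), 0)]  # so does the conversation opening

    return (system + early + recent)[:max_messages]


# Eight messages squeezed into five: the middle turns 2-4 are dropped
history = [{"role": "system", "content": "You are helpful."}] + [
    {"role": "user" if i % 2 == 0 else "assistant", "content": f"turn {i}"}
    for i in range(7)
]
print([m["content"] for m in middle_out_sketch(history, max_messages=5)])
# ['You are helpful.', 'turn 0', 'turn 1', 'turn 5', 'turn 6']
```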
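
### Model Limit Lookup

Model detection resolves a context limit in three steps: exact name match, then substring match, then a conservative 4096-token default (logged with a warning in the real implementation). The sketch below mirrors that lookup order against an assumed subset of the limits table; the full table lives in `swarms/structs/transforms.py`.

```python
MODEL_LIMITS = {  # assumed subset of the built-in table
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "claude-3-sonnet": 200000,
}


def context_limit_sketch(model_name: str) -> int:
    """Exact match first, then substring match, then a safe default."""
    if model_name in MODEL_LIMITS:
        return MODEL_LIMITS[model_name]
    for key, limit in MODEL_LIMITS.items():
        if key in model_name.lower():
            return limit
    return 4096  # fallback for unknown models


print(context_limit_sketch("gpt-4o"))  # 128000 (exact match)
print(context_limit_sketch("claude-3-sonnet-20240229"))  # 200000 (substring match)
print(context_limit_sketch("mistral-large"))  # 4096 (fallback)
```

Substring matching means dated or vendor-prefixed model names still resolve to a sensible limit, at the cost of occasionally matching a shorter key first.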
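
### Verifying That Transforms Fire

The quickest activation check from the troubleshooting list is to force an artificially small token budget and inspect `was_compressed` on the returned `TransformResult`. This probe uses only the public API documented on this page; exact token counts depend on the tokenizer, so treat the printed numbers as illustrative.

```python
from swarms.structs.transforms import MessageTransforms, TransformConfig

# Deliberately tiny budget so compression must trigger
transforms = MessageTransforms(
    TransformConfig(enabled=True, max_tokens=500, model_name="gpt-4")
)

# Ten messages of roughly 200 tokens each, ~2000 tokens in total
messages = [
    {"role": "user", "content": "word " * 200} for _ in range(10)
]

result = transforms.transform_messages(messages)
assert result.was_compressed, "transforms did not trigger - check the config"
print(
    f"{result.original_token_count} -> {result.compressed_token_count} tokens "
    f"(ratio {result.compression_ratio:.2f})"
)
```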
diff --git a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_basic_demo.py b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_basic_demo.py index bf188a8e..a0738b5d 100644 --- a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_basic_demo.py +++ b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_basic_demo.py @@ -46,11 +46,8 @@ if __name__ == "__main__": print("\nExecuting with streaming callback:\n") # Run with streaming - result = swarm.run( - task=task, - streaming_callback=simple_callback - ) + result = swarm.run(task=task, streaming_callback=simple_callback) - print("\n" + "="*30) + print("\n" + "=" * 30) print("Final result:") print(result) diff --git a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_batch_demo.py b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_batch_demo.py index 0e4c71d4..badb2375 100644 --- a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_batch_demo.py +++ b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_batch_demo.py @@ -23,7 +23,11 @@ def create_batch_callback() -> Callable[[str, str, bool], None]: print(f"\nāœ… [{timestamp}] {agent_name} COMPLETED") else: # Shorter output for batch processing - print(f"šŸ”„ {agent_name}: {chunk[:30]}..." if len(chunk) > 30 else f"šŸ”„ {agent_name}: {chunk}") + print( + f"šŸ”„ {agent_name}: {chunk[:30]}..." + if len(chunk) > 30 + else f"šŸ”„ {agent_name}: {chunk}" + ) return batch_callback @@ -50,7 +54,7 @@ def create_agents(): if __name__ == "__main__": print("šŸ“¦ HIERARCHICAL SWARM BATCH PROCESSING DEMO") - print("="*50) + print("=" * 50) # Create agents and swarm agents = create_agents() @@ -67,7 +71,7 @@ if __name__ == "__main__": tasks = [ "What are the environmental benefits of solar energy?", "How does wind power contribute to sustainable development?", - "What are the economic advantages of hydroelectric power?" + "What are the economic advantages of hydroelectric power?", ] print(f"Processing {len(tasks)} tasks:") @@ -87,7 +91,7 @@ if __name__ == "__main__": streaming_callback=streaming_callback, ) - print("\n" + "="*50) + print("\n" + "=" * 50) print("šŸŽ‰ BATCH PROCESSING COMPLETED!") print(f"Processed {len(results)} tasks") diff --git a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_comparison_demo.py b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_comparison_demo.py index 1a5980f7..d2ef65d1 100644 --- a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_comparison_demo.py +++ b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_comparison_demo.py @@ -61,22 +61,30 @@ def run_traditional_swarm(): print("\nResult:") if isinstance(result, dict): for key, value in result.items(): - print(f"{key}: {value[:200]}..." if len(str(value)) > 200 else f"{key}: {value}") + print( + f"{key}: {value[:200]}..." + if len(str(value)) > 200 + else f"{key}: {value}" + ) else: - print(result[:500] + "..." if len(str(result)) > 500 else result) + print( + result[:500] + "..." if len(str(result)) > 500 else result + ) def run_streaming_swarm(): """Run swarm with streaming callbacks.""" - import time - from typing import Callable def simple_callback(agent_name: str, chunk: str, is_final: bool): if chunk.strip(): if is_final: print(f"\nāœ… {agent_name} completed") else: - print(f"šŸ”„ {agent_name}: {chunk[:50]}..." if len(chunk) > 50 else f"šŸ”„ {agent_name}: {chunk}") + print( + f"šŸ”„ {agent_name}: {chunk[:50]}..." 
+ if len(chunk) > 50 + else f"šŸ”„ {agent_name}: {chunk}" + ) print("\nšŸŽÆ STREAMING SWARM EXECUTION") print("-" * 50) @@ -95,22 +103,25 @@ def run_streaming_swarm(): print(f"Task: {task}") - result = swarm.run( - task=task, - streaming_callback=simple_callback - ) + result = swarm.run(task=task, streaming_callback=simple_callback) print("\nResult:") if isinstance(result, dict): for key, value in result.items(): - print(f"{key}: {value[:200]}..." if len(str(value)) > 200 else f"{key}: {value}") + print( + f"{key}: {value[:200]}..." + if len(str(value)) > 200 + else f"{key}: {value}" + ) else: - print(result[:500] + "..." if len(str(result)) > 500 else result) + print( + result[:500] + "..." if len(str(result)) > 500 else result + ) if __name__ == "__main__": print("šŸ”„ HIERARCHICAL SWARM COMPARISON DEMO") - print("="*50) + print("=" * 50) print("Comparing traditional vs streaming execution\n") # Run traditional first @@ -119,6 +130,6 @@ if __name__ == "__main__": # Run streaming second run_streaming_swarm() - print("\n" + "="*50) + print("\n" + "=" * 50) print("✨ Comparison complete!") print("Notice how streaming shows progress in real-time") diff --git a/hierarchical_swarm_streaming_demo.py b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_demo.py similarity index 79% rename from hierarchical_swarm_streaming_demo.py rename to examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_demo.py index 78b0b6cb..fa65fd60 100644 --- a/hierarchical_swarm_streaming_demo.py +++ b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_demo.py @@ -11,7 +11,9 @@ def create_streaming_callback() -> Callable[[str, str, bool], None]: agent_buffers = {} paragraph_count = {} - def streaming_callback(agent_name: str, chunk: str, is_final: bool): + def streaming_callback( + agent_name: str, chunk: str, is_final: bool + ): timestamp = time.strftime("%H:%M:%S") # Initialize buffers for new agents @@ -19,19 +21,21 @@ def create_streaming_callback() -> Callable[[str, str, bool], None]: agent_buffers[agent_name] = "" paragraph_count[agent_name] = 1 print(f"\nšŸŽ¬ [{timestamp}] {agent_name} starting...") - print("="*60) + print("=" * 60) if chunk.strip(): # Split chunk into tokens (words/punctuation) - tokens = chunk.replace('\n', ' \n ').split() + tokens = chunk.replace("\n", " \n ").split() for token in tokens: # Handle paragraph breaks - if token == '\n': + if token == "\n": if agent_buffers[agent_name].strip(): - print(f"\nšŸ“„ [{timestamp}] {agent_name} - Paragraph {paragraph_count[agent_name]} Complete:") + print( + f"\nšŸ“„ [{timestamp}] {agent_name} - Paragraph {paragraph_count[agent_name]} Complete:" + ) print(f"{agent_buffers[agent_name].strip()}") - print("="*60) + print("=" * 60) paragraph_count[agent_name] += 1 agent_buffers[agent_name] = "" else: @@ -39,19 +43,29 @@ def create_streaming_callback() -> Callable[[str, str, bool], None]: agent_buffers[agent_name] += token + " " # Clear line and show current paragraph - print(f"\r[{timestamp}] {agent_name} | {agent_buffers[agent_name].strip()}", end="", flush=True) + print( + f"\r[{timestamp}] {agent_name} | {agent_buffers[agent_name].strip()}", + end="", + flush=True, + ) if is_final: print() # New line after live updates # Print any remaining content as final paragraph if agent_buffers[agent_name].strip(): - print(f"\nāœ… [{timestamp}] {agent_name} COMPLETED - Final Paragraph:") + print( + f"\nāœ… [{timestamp}] {agent_name} COMPLETED - Final Paragraph:" + ) print(f"{agent_buffers[agent_name].strip()}") print() 
- print(f"šŸŽÆ [{timestamp}] {agent_name} finished processing") - print(f"šŸ“Š Total paragraphs processed: {paragraph_count[agent_name] - 1}") - print("="*60) + print( + f"šŸŽÆ [{timestamp}] {agent_name} finished processing" + ) + print( + f"šŸ“Š Total paragraphs processed: {paragraph_count[agent_name] - 1}" + ) + print("=" * 60) return streaming_callback @@ -88,7 +102,7 @@ def create_agents(): if __name__ == "__main__": print("šŸŽÆ HIERARCHICAL SWARM STREAMING DEMO") - print("="*50) + print("=" * 50) # Create agents and swarm agents = create_agents() diff --git a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_example.py b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_example.py index 41c23da0..2dcd941d 100644 --- a/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_example.py +++ b/examples/multi_agent/hiearchical_swarm/hierarchical_swarm_streaming_example.py @@ -33,7 +33,7 @@ def streaming_callback(agent_name: str, chunk: str, is_final: bool): timestamp = time.strftime("%H:%M:%S") # Store accumulated text for each agent to track paragraph formation - if not hasattr(streaming_callback, 'agent_buffers'): + if not hasattr(streaming_callback, "agent_buffers"): streaming_callback.agent_buffers = {} streaming_callback.paragraph_count = {} @@ -42,39 +42,59 @@ def streaming_callback(agent_name: str, chunk: str, is_final: bool): streaming_callback.agent_buffers[agent_name] = "" streaming_callback.paragraph_count[agent_name] = 1 print(f"\nšŸŽ¬ [{timestamp}] {agent_name} starting...") - print("="*60) + print("=" * 60) if chunk.strip(): # Split chunk into tokens (words/punctuation) - tokens = chunk.replace('\n', ' \n ').split() + tokens = chunk.replace("\n", " \n ").split() for token in tokens: # Handle paragraph breaks - if token == '\n': - if streaming_callback.agent_buffers[agent_name].strip(): - print(f"\nšŸ“„ [{timestamp}] {agent_name} - Paragraph {streaming_callback.paragraph_count[agent_name]} Complete:") - print(f"{streaming_callback.agent_buffers[agent_name].strip()}") - print("="*60) - streaming_callback.paragraph_count[agent_name] += 1 + if token == "\n": + if streaming_callback.agent_buffers[ + agent_name + ].strip(): + print( + f"\nšŸ“„ [{timestamp}] {agent_name} - Paragraph {streaming_callback.paragraph_count[agent_name]} Complete:" + ) + print( + f"{streaming_callback.agent_buffers[agent_name].strip()}" + ) + print("=" * 60) + streaming_callback.paragraph_count[ + agent_name + ] += 1 streaming_callback.agent_buffers[agent_name] = "" else: # Add token to buffer and show live accumulation - streaming_callback.agent_buffers[agent_name] += token + " " + streaming_callback.agent_buffers[agent_name] += ( + token + " " + ) # Clear line and show current paragraph - print(f"\r[{timestamp}] {agent_name} | {streaming_callback.agent_buffers[agent_name].strip()}", end="", flush=True) + print( + f"\r[{timestamp}] {agent_name} | {streaming_callback.agent_buffers[agent_name].strip()}", + end="", + flush=True, + ) if is_final: print() # New line after live updates # Print any remaining content as final paragraph if streaming_callback.agent_buffers[agent_name].strip(): - print(f"\nāœ… [{timestamp}] {agent_name} COMPLETED - Final Paragraph:") - print(f"{streaming_callback.agent_buffers[agent_name].strip()}") + print( + f"\nāœ… [{timestamp}] {agent_name} COMPLETED - Final Paragraph:" + ) + print( + f"{streaming_callback.agent_buffers[agent_name].strip()}" + ) print() print(f"šŸŽÆ [{timestamp}] {agent_name} finished processing") - print(f"šŸ“Š 
Total paragraphs processed: {streaming_callback.paragraph_count[agent_name] - 1}") - print("="*60) + print( + f"šŸ“Š Total paragraphs processed: {streaming_callback.paragraph_count[agent_name] - 1}" + ) + print("=" * 60) def create_sample_agents(): @@ -141,15 +161,18 @@ def main(): """ print(f"šŸ“‹ Task: {task.strip()}") - print("\nšŸŽÆ Starting hierarchical swarm with live paragraph streaming...") + print( + "\nšŸŽÆ Starting hierarchical swarm with live paragraph streaming..." + ) print("Watch as agents build complete paragraphs in real-time!\n") - print("Each token accumulates to form readable text, showing the full paragraph as it builds.\n") + print( + "Each token accumulates to form readable text, showing the full paragraph as it builds.\n" + ) # Run the swarm with streaming callback try: result = swarm.run( - task=task, - streaming_callback=streaming_callback + task=task, streaming_callback=streaming_callback ) print("\nšŸŽ‰ Swarm execution completed!") @@ -168,7 +191,7 @@ def simple_callback_example(): def simple_callback(agent_name: str, chunk: str, is_final: bool): """Simple callback that shows live paragraph formation.""" - if not hasattr(simple_callback, 'buffer'): + if not hasattr(simple_callback, "buffer"): simple_callback.buffer = {} simple_callback.token_count = {} @@ -177,18 +200,26 @@ def simple_callback_example(): simple_callback.token_count[agent_name] = 0 if chunk.strip(): - tokens = chunk.replace('\n', ' \n ').split() + tokens = chunk.replace("\n", " \n ").split() for token in tokens: if token.strip(): simple_callback.token_count[agent_name] += 1 simple_callback.buffer[agent_name] += token + " " # Show live accumulation - print(f"\r{agent_name} | {simple_callback.buffer[agent_name].strip()}", end="", flush=True) + print( + f"\r{agent_name} | {simple_callback.buffer[agent_name].strip()}", + end="", + flush=True, + ) if is_final: print() # New line after live updates - print(f"āœ“ {agent_name} finished! Total tokens: {simple_callback.token_count[agent_name]}") - print(f"Final text: {simple_callback.buffer[agent_name].strip()}") + print( + f"āœ“ {agent_name} finished! 
Total tokens: {simple_callback.token_count[agent_name]}" + ) + print( + f"Final text: {simple_callback.buffer[agent_name].strip()}" + ) print("-" * 40) # Create simple agents diff --git a/examples/single_agent/utils/transform_prompts/transforms_agent_example.py b/examples/single_agent/utils/transform_prompts/transforms_agent_example.py new file mode 100644 index 00000000..470c1e34 --- /dev/null +++ b/examples/single_agent/utils/transform_prompts/transforms_agent_example.py @@ -0,0 +1,50 @@ +from swarms import Agent +from swarms.structs.transforms import TransformConfig + +# Initialize the agent with message transforms enabled +# This will automatically handle context size limits using middle-out compression +agent = Agent( + agent_name="Quantitative-Trading-Agent", + agent_description="Advanced quantitative trading and algorithmic analysis agent", + model_name="claude-sonnet-4-20250514", + dynamic_temperature_enabled=True, + max_loops=1, + dynamic_context_window=True, + streaming_on=False, + print_on=False, + # Enable message transforms for handling context limits + transforms=TransformConfig( + enabled=True, + method="middle-out", + model_name="claude-sonnet-4-20250514", + preserve_system_messages=True, + preserve_recent_messages=2, + ), +) + +# Alternative way to configure transforms using dictionary +# agent_with_dict_transforms = Agent( +# agent_name="Trading-Agent-Dict", +# model_name="gpt-4o", +# max_loops=1, +# transforms={ +# "enabled": True, +# "method": "middle-out", +# "model_name": "gpt-4o", +# "preserve_system_messages": True, +# "preserve_recent_messages": 3, +# }, +# ) + +out = agent.run( + task="What are the top five best energy stocks across nuclear, solar, gas, and other energy sources?", +) + +print(out) + +# The transforms feature provides: +# 1. Automatic context size management for models with token limits +# 2. Message count management for models like Claude with 1000 message limits +# 3. Middle-out compression that preserves important context (beginning and recent messages) +# 4. Smart model selection based on context requirements +# 5. Detailed logging of compression statistics diff --git a/examples/single_agent/utils/transform_prompts/transforms_examples.py b/examples/single_agent/utils/transform_prompts/transforms_examples.py new file mode 100644 index 00000000..ef0bd092 --- /dev/null +++ b/examples/single_agent/utils/transform_prompts/transforms_examples.py @@ -0,0 +1,420 @@ +""" +Message Transforms Examples for Swarms + +This file demonstrates various ways to use the message transforms feature +for handling context size limitations and message count restrictions in LLMs. +""" + +from swarms import Agent +from swarms.structs.transforms import ( + MessageTransforms, + TransformConfig, + TransformResult, + create_default_transforms, + apply_transforms_to_messages, +) + +from loguru import logger + + +def example_1_basic_transforms(): + """Example 1: Basic transforms usage with TransformConfig""" + print("\n" + "=" * 60) + print("Example 1: Basic Transforms with TransformConfig") + print("=" * 60) + + # Create sample messages that exceed typical context limits + sample_messages = [ + { + "role": "system", + "content": "You are a helpful AI assistant specializing in software development.", + }, + { + "role": "user", + "content": "Explain the concept of object-oriented programming.", + }, + { + "role": "assistant", + "content": "Object-oriented programming (OOP) is a programming paradigm based on the concept of 'objects', which can contain data and code. 
The data is in the form of fields (often called attributes or properties), and the code is in the form of procedures (often called methods). OOP focuses on the objects that developers want to manipulate rather than the logic required to manipulate them.", + }, + { + "role": "user", + "content": "Can you give me a simple example in Python?", + }, + { + "role": "assistant", + "content": "Here's a simple example of a class in Python:\n\n```python\nclass Dog:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n\n def bark(self):\n return f'{self.name} says woof!'\n\n# Usage\nmy_dog = Dog('Buddy', 3)\nprint(my_dog.bark()) # Output: Buddy says woof!\n```\n\nThis example shows:\n1. A `Dog` class with attributes `name` and `age`\n2. An `__init__` method (constructor)\n3. A `bark` method\n4. How to create an instance and call methods", + }, + { + "role": "user", + "content": "What are the four main principles of OOP?", + }, + { + "role": "assistant", + "content": "The four main principles of Object-Oriented Programming are:\n\n1. **Encapsulation**: Bundling data and methods that operate on that data within a single unit (class). It restricts direct access to some of an object's components.\n\n2. **Abstraction**: Hiding complex implementation details and showing only the necessary features of an object. This simplifies the interface.\n\n3. **Inheritance**: Creating new classes from existing classes, allowing code reuse and establishing a relationship between parent and child classes.\n\n4. **Polymorphism**: The ability of different objects to respond to the same method call in different ways. This allows for more flexible and extensible code.", + }, + { + "role": "user", + "content": "How does inheritance work in practice?", + }, + { + "role": "assistant", + "content": "Let me show you a practical example of inheritance:\n\n```python\nclass Animal:\n def __init__(self, name):\n self.name = name\n\n def speak(self):\n pass # Abstract method\n\nclass Dog(Animal):\n def speak(self):\n return f'{self.name} says Woof!'\n\nclass Cat(Animal):\n def speak(self):\n return f'{self.name} says Meow!'\n\n# Usage\ndog = Dog('Buddy')\ncat = Cat('Whiskers')\nprint(dog.speak()) # Buddy says Woof!\nprint(cat.speak()) # Whiskers says Meow!\n```\n\nHere, `Dog` and `Cat` inherit from `Animal`. They both have a `speak` method, but each implements it differently - this is polymorphism in action!", + }, + { + "role": "user", + "content": "What are some real-world applications of OOP?", + }, + { + "role": "assistant", + "content": "Object-oriented programming is used extensively in real-world applications:\n\n1. **GUI Applications**: Frameworks like Qt, Swing, and WPF use OOP to create user interfaces with windows, buttons, and other components as objects.\n\n2. **Game Development**: Game engines like Unity and Unreal Engine heavily use OOP. Characters, weapons, environments, and game logic are all implemented as objects.\n\n3. **Web Applications**: Frameworks like Django (Python), Ruby on Rails, and ASP.NET use OOP patterns for handling requests, managing databases, and organizing code.\n\n4. **Database Systems**: Object-Relational Mapping (ORM) systems like SQLAlchemy, Hibernate, and Entity Framework map database tables to objects.\n\n5. **Operating Systems**: Modern operating systems use OOP for managing processes, files, network connections, and hardware devices.\n\n6. **Mobile Applications**: iOS (Objective-C/Swift) and Android (Java/Kotlin) development heavily relies on OOP concepts.\n\n7. 
**Scientific Computing**: Libraries like NumPy and SciPy, while not purely OOP, use objects to represent complex mathematical structures.\n\n8. **Financial Systems**: Banking software, trading platforms, and financial modeling tools use OOP to represent accounts, transactions, and complex financial instruments.\n\nThe benefits of OOP in these applications include:\n- Code reusability\n- Easier maintenance\n- Better organization\n- Encapsulation of complex logic\n- Inheritance for building upon existing functionality",
        },
    ]

    # Create transforms configuration
    config = TransformConfig(
        enabled=True,
        method="middle-out",
        model_name="gpt-4",
        preserve_system_messages=True,
        preserve_recent_messages=3,
        max_tokens=4000,  # Force compression for demonstration
    )

    # Create transforms instance
    transforms = MessageTransforms(config)

    # Apply transforms
    result: TransformResult = transforms.transform_messages(
        sample_messages
    )

    print(f"Original messages: {result.original_message_count}")
    print(f"Compressed messages: {result.compressed_message_count}")
    print(f"Original tokens: {result.original_token_count}")
    print(f"Compressed tokens: {result.compressed_token_count}")
    print(f"Compression ratio: {result.compression_ratio:.2f}")

    if result.was_compressed:
        print("\nTransformed messages:")
        for i, msg in enumerate(result.messages, 1):
            print(
                f"{i}. {msg['role']}: {msg['content'][:100]}{'...' if len(msg['content']) > 100 else ''}"
            )
    else:
        print("No compression was needed.")


def example_2_dictionary_config():
    """Example 2: Using dictionary configuration"""
    print("\n" + "=" * 60)
    print("Example 2: Dictionary Configuration")
    print("=" * 60)

    # Create transforms using dictionary (alternative to TransformConfig)
    dict_config = {
        "enabled": True,
        "method": "middle-out",
        "model_name": "claude-3-sonnet",
        "preserve_system_messages": True,
        "preserve_recent_messages": 2,
        "max_messages": 5,  # Force message count compression
    }

    config = TransformConfig(**dict_config)
    transforms = MessageTransforms(config)

    # Sample messages
    messages = [
        {"role": "system", "content": "You are a coding assistant."},
        {
            "role": "user",
            "content": "Help me debug this Python code.",
        },
        {
            "role": "assistant",
            "content": "I'd be happy to help! Please share your Python code and describe the issue you're experiencing.",
        },
        {
            "role": "user",
            "content": "Here's my code: def factorial(n): if n == 0: return 1 else: return n * factorial(n-1)",
        },
        {
            "role": "assistant",
            "content": "Your factorial function looks correct! It's a classic recursive implementation. However, it doesn't handle negative numbers. Let me suggest an improved version...",
        },
        {
            "role": "user",
            "content": "It works for positive numbers but crashes for large n due to recursion depth.",
        },
        {
            "role": "assistant",
            "content": "Ah, that's a common issue with recursive factorial functions. Python has a default recursion limit of 1000. For large numbers, you should use an iterative approach instead...",
        },
        {
            "role": "user",
            "content": "Can you show me the iterative version?",
        },
    ]

    result = transforms.transform_messages(messages)

    print(f"Original messages: {result.original_message_count}")
    print(f"Compressed messages: {result.compressed_message_count}")
    print(f"Compression applied: {result.was_compressed}")

    if result.was_compressed:
        print("\nCompressed conversation:")
        for msg in result.messages:
            print(
                f"{msg['role'].title()}: {msg['content'][:80]}{'...' 
if len(msg['content']) > 80 else ''}" + ) + + +def example_3_agent_integration(): + """Example 3: Integration with Agent class""" + print("\n" + "=" * 60) + print("Example 3: Agent Integration") + print("=" * 60) + + # Create agent with transforms enabled + agent = Agent( + agent_name="Transformed-Agent", + agent_description="AI assistant with automatic context management", + model_name="gpt-4o", + max_loops=1, + streaming_on=False, + print_on=False, + # Enable transforms + transforms=TransformConfig( + enabled=True, + method="middle-out", + model_name="gpt-4o", + preserve_system_messages=True, + preserve_recent_messages=3, + ), + ) + + print("Agent created with transforms enabled.") + print( + "The agent will automatically apply message transforms when context limits are approached." + ) + + # You can also check if transforms are active + if agent.transforms is not None: + print("āœ“ Transforms are active on this agent") + print(f" Method: {agent.transforms.config.method}") + print(f" Model: {agent.transforms.config.model_name}") + print( + f" Preserve recent: {agent.transforms.config.preserve_recent_messages}" + ) + else: + print("āœ— No transforms configured") + + +def example_4_convenience_function(): + """Example 4: Using convenience functions""" + print("\n" + "=" * 60) + print("Example 4: Convenience Functions") + print("=" * 60) + + # Sample messages + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "Tell me about machine learning.", + }, + { + "role": "assistant", + "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed. It involves algorithms that can identify patterns in data and make predictions or decisions.", + }, + {"role": "user", "content": "What are the main types?"}, + { + "role": "assistant", + "content": "There are three main types of machine learning:\n\n1. **Supervised Learning**: The algorithm learns from labeled training data. Examples include classification and regression tasks.\n\n2. **Unsupervised Learning**: The algorithm finds patterns in unlabeled data. Examples include clustering and dimensionality reduction.\n\n3. **Reinforcement Learning**: The algorithm learns through trial and error by interacting with an environment. 
Examples include game playing and robotic control.", + }, + {"role": "user", "content": "Can you give examples of each?"}, + ] + + # Method 1: Using create_default_transforms + print("Method 1: create_default_transforms") + transforms = create_default_transforms( + enabled=True, + model_name="gpt-3.5-turbo", + ) + result1 = transforms.transform_messages(messages) + print( + f"Default transforms - Original: {result1.original_message_count}, Compressed: {result1.compressed_message_count}" + ) + + # Method 2: Using apply_transforms_to_messages directly + print("\nMethod 2: apply_transforms_to_messages") + config = TransformConfig( + enabled=True, max_tokens=1000 + ) # Force compression + result2 = apply_transforms_to_messages(messages, config, "gpt-4") + print( + f"Direct function - Original tokens: {result2.original_token_count}, Compressed tokens: {result2.compressed_token_count}" + ) + + +def example_5_advanced_scenarios(): + """Example 5: Advanced compression scenarios""" + print("\n" + "=" * 60) + print("Example 5: Advanced Scenarios") + print("=" * 60) + + # Scenario 1: Very long conversation with many messages + print("Scenario 1: Long conversation (100+ messages)") + long_messages = [] + for i in range(150): # Create 150 messages + role = "user" if i % 2 == 0 else "assistant" + content = f"Message {i+1}: {' '.join([f'word{j}' for j in range(20)])}" # Make each message longer + long_messages.append({"role": role, "content": content}) + + # Add system message at the beginning + long_messages.insert( + 0, + { + "role": "system", + "content": "You are a helpful assistant in a very long conversation.", + }, + ) + + config = TransformConfig( + enabled=True, + max_messages=20, # Very restrictive limit + preserve_system_messages=True, + preserve_recent_messages=5, + ) + transforms = MessageTransforms(config) + result = transforms.transform_messages(long_messages) + + print( + f"Long conversation: {result.original_message_count} -> {result.compressed_message_count} messages" + ) + print( + f"Token reduction: {result.original_token_count} -> {result.compressed_token_count}" + ) + + # Scenario 2: Token-heavy messages + print("\nScenario 2: Token-heavy content") + token_heavy_messages = [ + {"role": "system", "content": "You are analyzing code."}, + { + "role": "user", + "content": "Analyze this Python file: " + "x = 1\n" * 500, + }, # Very long code + { + "role": "assistant", + "content": "This appears to be a Python file that repeatedly assigns 1 to variable x. 
" + * 100, + }, + ] + + config = TransformConfig( + enabled=True, + max_tokens=2000, # Restrictive token limit + preserve_system_messages=True, + ) + result = transforms.transform_messages(token_heavy_messages) + print( + f"Token-heavy content: {result.original_token_count} -> {result.compressed_token_count} tokens" + ) + + # Scenario 3: Mixed content types + print("\nScenario 3: Mixed message types") + mixed_messages = [ + { + "role": "system", + "content": "You handle various content types.", + }, + { + "role": "user", + "content": "Process this data: [1, 2, 3, 4, 5] * 50", + }, # List-like content + { + "role": "assistant", + "content": "I've processed your list data.", + }, + { + "role": "user", + "content": "Now process this dict: {'key': 'value'} * 30", + }, # Dict-like content + { + "role": "assistant", + "content": "Dictionary processed successfully.", + }, + ] + + result = transforms.transform_messages(mixed_messages) + print( + f"Mixed content: {result.original_message_count} -> {result.compressed_message_count} messages" + ) + + +def example_6_model_specific_limits(): + """Example 6: Model-specific context limits""" + print("\n" + "=" * 60) + print("Example 6: Model-Specific Limits") + print("=" * 60) + + # Test different models and their limits + models_and_limits = [ + ("gpt-4", 8192), + ("gpt-4-turbo", 128000), + ("claude-3-sonnet", 200000), + ("claude-2", 100000), + ("gpt-3.5-turbo", 16385), + ] + + sample_content = "This is a sample message. " * 100 # ~300 tokens + messages = [ + {"role": "user", "content": sample_content} for _ in range(10) + ] + + for model, expected_limit in models_and_limits: + config = TransformConfig( + enabled=True, + model_name=model, + preserve_system_messages=False, + ) + transforms = MessageTransforms(config) + result = transforms.transform_messages(messages) + + print( + f"{model}: {result.original_token_count} -> {result.compressed_token_count} tokens (limit: {expected_limit})" + ) + + +def main(): + """Run all examples""" + print("šŸš€ Swarms Message Transforms Examples") + print("=" * 60) + + try: + example_1_basic_transforms() + example_2_dictionary_config() + example_3_agent_integration() + example_4_convenience_function() + example_5_advanced_scenarios() + example_6_model_specific_limits() + + print("\n" + "=" * 60) + print("āœ… All examples completed successfully!") + print("=" * 60) + print("\nKey takeaways:") + print("• Transforms automatically handle context size limits") + print("• Middle-out compression preserves important context") + print("• System messages and recent messages are prioritized") + print("• Works with any LLM model through the Agent class") + print("• Detailed logging shows compression statistics") + + except Exception as e: + logger.error(f"Error running examples: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/swarms/sims/senator_assembly.py b/swarms/sims/senator_assembly.py index bff4ea42..bae16060 100644 --- a/swarms/sims/senator_assembly.py +++ b/swarms/sims/senator_assembly.py @@ -1,3 +1,9 @@ +""" + +Senator Assembly: A Large-Scale Multi-Agent Simulation of the US Senate + +""" + from functools import lru_cache from typing import Dict, List, Optional diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 00f8dc60..578f42a6 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -57,6 +57,11 @@ from swarms.schemas.mcp_schemas import ( from swarms.structs.agent_roles import agent_roles from swarms.structs.conversation import Conversation from swarms.structs.ma_utils 
import set_random_models_for_agents +from swarms.structs.transforms import ( + MessageTransforms, + TransformConfig, + handle_transforms, +) from swarms.structs.safe_loading import ( SafeLoaderUtils, SafeStateManager, @@ -188,6 +193,7 @@ class Agent: saved_state_path (str): The path to the saved state autosave (bool): Autosave the state context_length (int): The context length + transforms (Optional[Union[TransformConfig, dict]]): Message transformation configuration for handling context limits user_name (str): The user name self_healing_enabled (bool): Enable self healing code_interpreter (bool): Enable code interpreter @@ -324,6 +330,7 @@ class Agent: saved_state_path: Optional[str] = None, autosave: Optional[bool] = False, context_length: Optional[int] = 8192, + transforms: Optional[Union[TransformConfig, dict]] = None, user_name: Optional[str] = "Human", self_healing_enabled: Optional[bool] = False, code_interpreter: Optional[bool] = False, @@ -459,6 +466,20 @@ class Agent: self.dynamic_loops = dynamic_loops self.user_name = user_name self.context_length = context_length + + # Initialize transforms + if transforms is None: + self.transforms = None + elif isinstance(transforms, TransformConfig): + self.transforms = MessageTransforms(transforms) + elif isinstance(transforms, dict): + config = TransformConfig(**transforms) + self.transforms = MessageTransforms(config) + else: + raise ValueError( + "transforms must be a TransformConfig object or a dictionary" + ) + self.sop = sop self.sop_list = sop_list self.tools = tools @@ -1162,10 +1183,19 @@ class Agent: if self.dynamic_temperature_enabled is True: self.dynamic_temperature() - # Task prompt - task_prompt = ( - self.short_memory.return_history_as_string() - ) + # Task prompt with optional transforms + if self.transforms is not None: + task_prompt = handle_transforms( + transforms=self.transforms, + short_memory=self.short_memory, + model_name=self.model_name, + ) + + else: + # Use original method if no transforms + task_prompt = ( + self.short_memory.return_history_as_string() + ) # Parameters attempt = 0 diff --git a/swarms/structs/hiearchical_swarm.py b/swarms/structs/hiearchical_swarm.py index abf5cfdf..40461b1f 100644 --- a/swarms/structs/hiearchical_swarm.py +++ b/swarms/structs/hiearchical_swarm.py @@ -743,11 +743,10 @@ class HierarchicalSwarm: self.multi_agent_prompt_improvements = ( multi_agent_prompt_improvements ) - - self.reliability_checks() + self.initialize_swarm() - def reliability_checks(self): + def initialize_swarm(self): if self.interactive: self.agents_no_print() @@ -767,7 +766,7 @@ class HierarchicalSwarm: ) self.init_swarm() - + def list_worker_agents(self) -> str: return list_all_agents( agents=self.agents, @@ -798,7 +797,7 @@ class HierarchicalSwarm: Returns: str: The reasoning output from the agent """ - + agent = Agent( agent_name=self.director_name, agent_description=f"You're the {self.director_name} agent that is responsible for reasoning about the task and creating a plan for the swarm to accomplish the task.", @@ -1044,9 +1043,16 @@ class HierarchicalSwarm: logger.error(error_msg) raise e - def step(self, task: str, img: str = None, streaming_callback: Optional[ - Callable[[str, str, bool], None] - ] = None, *args, **kwargs): + def step( + self, + task: str, + img: str = None, + streaming_callback: Optional[ + Callable[[str, str, bool], None] + ] = None, + *args, + **kwargs, + ): """ Execute a single step of the hierarchical swarm workflow. 
@@ -1106,7 +1112,9 @@ class HierarchicalSwarm: self.dashboard.update_director_status("EXECUTING") # Execute the orders - outputs = self.execute_orders(orders, streaming_callback=streaming_callback) + outputs = self.execute_orders( + orders, streaming_callback=streaming_callback + ) if self.verbose: logger.info(f"[EXEC] Executed {len(outputs)} orders") @@ -1213,7 +1221,11 @@ class HierarchicalSwarm: # Execute one step of the swarm try: last_output = self.step( - task=loop_task, img=img, streaming_callback=streaming_callback, *args, **kwargs + task=loop_task, + img=img, + streaming_callback=streaming_callback, + *args, + **kwargs, ) if self.verbose: @@ -1334,9 +1346,14 @@ class HierarchicalSwarm: logger.error(error_msg) def call_single_agent( - self, agent_name: str, task: str, streaming_callback: Optional[ + self, + agent_name: str, + task: str, + streaming_callback: Optional[ Callable[[str, str, bool], None] - ] = None, *args, **kwargs + ] = None, + *args, + **kwargs, ): """ Call a single agent by name to execute a specific task. @@ -1393,11 +1410,14 @@ class HierarchicalSwarm: # Handle streaming callback if provided if streaming_callback is not None: + def agent_streaming_callback(chunk: str): """Wrapper for agent streaming callback.""" try: if chunk is not None and chunk.strip(): - streaming_callback(agent_name, chunk, False) + streaming_callback( + agent_name, chunk, False + ) except Exception as callback_error: if self.verbose: logger.warning( @@ -1589,9 +1609,13 @@ class HierarchicalSwarm: logger.error(error_msg) raise e - def execute_orders(self, orders: list, streaming_callback: Optional[ - Callable[[str, str, bool], None] - ] = None): + def execute_orders( + self, + orders: list, + streaming_callback: Optional[ + Callable[[str, str, bool], None] + ] = None, + ): """ Execute all orders from the director's output. @@ -1632,7 +1656,9 @@ class HierarchicalSwarm: ) output = self.call_single_agent( - order.agent_name, order.task, streaming_callback=streaming_callback + order.agent_name, + order.task, + streaming_callback=streaming_callback, ) # Update dashboard with completed status @@ -1661,9 +1687,14 @@ class HierarchicalSwarm: logger.error(error_msg) def batched_run( - self, tasks: List[str], img: str = None, streaming_callback: Optional[ + self, + tasks: List[str], + img: str = None, + streaming_callback: Optional[ Callable[[str, str, bool], None] - ] = None, *args, **kwargs + ] = None, + *args, + **kwargs, ): """ Execute the hierarchical swarm for multiple tasks in sequence. 
@@ -1701,7 +1732,13 @@ class HierarchicalSwarm: # Process each task in parallel for task in tasks: - result = self.run(task, img, streaming_callback=streaming_callback, *args, **kwargs) + result = self.run( + task, + img, + streaming_callback=streaming_callback, + *args, + **kwargs, + ) results.append(result) if self.verbose: diff --git a/swarms/structs/transforms.py b/swarms/structs/transforms.py new file mode 100644 index 00000000..1a705531 --- /dev/null +++ b/swarms/structs/transforms.py @@ -0,0 +1,521 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from loguru import logger + +from swarms.utils.litellm_tokenizer import count_tokens +from swarms.structs.conversation import Conversation + + +@dataclass +class TransformConfig: + """Configuration for message transforms.""" + + enabled: bool = False + method: str = "middle-out" + max_tokens: Optional[int] = None + max_messages: Optional[int] = None + model_name: str = "gpt-4" + preserve_system_messages: bool = True + preserve_recent_messages: int = 2 + + +@dataclass +class TransformResult: + """Result of message transformation.""" + + messages: List[Dict[str, Any]] + original_token_count: int + compressed_token_count: int + original_message_count: int + compressed_message_count: int + compression_ratio: float + was_compressed: bool + + +class MessageTransforms: + """ + Handles message transformations for context size management. + + Supports middle-out compression which removes or truncates messages + from the middle of the conversation while preserving the beginning + and end, which are typically more important for context. + """ + + def __init__(self, config: TransformConfig): + """ + Initialize the MessageTransforms with configuration. + + Args: + config: TransformConfig object with transformation settings + """ + self.config = config + + def transform_messages( + self, + messages: List[Dict[str, Any]], + target_model: Optional[str] = None, + ) -> TransformResult: + """ + Transform messages according to the configured strategy. 
+ + Args: + messages: List of message dictionaries with 'role' and 'content' keys + target_model: Optional target model name to determine context limits + + Returns: + TransformResult containing transformed messages and metadata + """ + if not self.config.enabled or not messages: + return TransformResult( + messages=messages, + original_token_count=self._count_total_tokens( + messages + ), + compressed_token_count=self._count_total_tokens( + messages + ), + original_message_count=len(messages), + compressed_message_count=len(messages), + compression_ratio=1.0, + was_compressed=False, + ) + + # Use target model if provided, otherwise use config model + model_name = target_model or self.config.model_name + + # Get model context limits + max_tokens = self._get_model_context_limit(model_name) + max_messages = self._get_model_message_limit(model_name) + + # Override with config values if specified + if self.config.max_tokens is not None: + max_tokens = self.config.max_tokens + if self.config.max_messages is not None: + max_messages = self.config.max_messages + + original_tokens = self._count_total_tokens(messages) + original_messages = len(messages) + + transformed_messages = messages.copy() + + # Apply transformations + if max_messages and len(transformed_messages) > max_messages: + transformed_messages = self._compress_message_count( + transformed_messages, max_messages + ) + + if ( + max_tokens + and self._count_total_tokens(transformed_messages) + > max_tokens + ): + transformed_messages = self._compress_tokens( + transformed_messages, max_tokens + ) + + compressed_tokens = self._count_total_tokens( + transformed_messages + ) + compressed_messages = len(transformed_messages) + + compression_ratio = ( + compressed_tokens / original_tokens + if original_tokens > 0 + else 1.0 + ) + + return TransformResult( + messages=transformed_messages, + original_token_count=original_tokens, + compressed_token_count=compressed_tokens, + original_message_count=original_messages, + compressed_message_count=compressed_messages, + compression_ratio=compression_ratio, + was_compressed=compressed_tokens < original_tokens + or compressed_messages < original_messages, + ) + + def _compress_message_count( + self, messages: List[Dict[str, Any]], max_messages: int + ) -> List[Dict[str, Any]]: + """ + Compress message count using middle-out strategy. 
+ + Args: + messages: List of messages to compress + max_messages: Maximum number of messages to keep + + Returns: + Compressed list of messages + """ + if len(messages) <= max_messages: + return messages + + # Always preserve system messages at the beginning + system_messages = [] + other_messages = [] + + for msg in messages: + if msg.get("role") == "system": + system_messages.append(msg) + else: + other_messages.append(msg) + + # Calculate how many non-system messages we can keep + available_slots = max_messages - len(system_messages) + if available_slots <= 0: + # If we can't fit any non-system messages, just return system messages + return system_messages[:max_messages] + + # Preserve recent messages + preserve_recent = min( + self.config.preserve_recent_messages, len(other_messages) + ) + recent_messages = ( + other_messages[-preserve_recent:] + if preserve_recent > 0 + else [] + ) + + # Calculate remaining slots for middle messages + remaining_slots = available_slots - len(recent_messages) + if remaining_slots <= 0: + # Only keep system messages and recent messages + result = system_messages + recent_messages + return result[:max_messages] + + # Get messages from the beginning (excluding recent ones) + early_messages = ( + other_messages[:-preserve_recent] + if preserve_recent > 0 + else other_messages + ) + + # If we have enough slots for all early messages + if len(early_messages) <= remaining_slots: + result = ( + system_messages + early_messages + recent_messages + ) + return result[:max_messages] + + # Apply middle-out compression to early messages + compressed_early = self._middle_out_compress( + early_messages, remaining_slots + ) + + result = system_messages + compressed_early + recent_messages + return result[:max_messages] + + def _compress_tokens( + self, messages: List[Dict[str, Any]], max_tokens: int + ) -> List[Dict[str, Any]]: + """ + Compress messages to fit within token limit using middle-out strategy. + + Args: + messages: List of messages to compress + max_tokens: Maximum token count + + Returns: + Compressed list of messages + """ + current_tokens = self._count_total_tokens(messages) + + if current_tokens <= max_tokens: + return messages + + # First try to compress message count if we have too many messages + if ( + len(messages) > 50 + ): # Arbitrary threshold for when to try message count compression first + messages = self._compress_message_count( + messages, len(messages) // 2 + ) + + current_tokens = self._count_total_tokens(messages) + if current_tokens <= max_tokens: + return messages + + # Apply middle-out compression with token awareness + return self._middle_out_compress_tokens(messages, max_tokens) + + def _middle_out_compress( + self, messages: List[Dict[str, Any]], target_count: int + ) -> List[Dict[str, Any]]: + """ + Apply middle-out compression to reduce message count. 
+ + Args: + messages: Messages to compress + target_count: Target number of messages + + Returns: + Compressed messages + """ + if len(messages) <= target_count: + return messages + + # Keep first half and last half + keep_count = target_count // 2 + first_half = messages[:keep_count] + last_half = messages[-keep_count:] + + # Combine first half, last half, and if odd number, add the middle message + result = first_half + last_half + + if target_count % 2 == 1 and len(messages) > keep_count * 2: + middle_index = len(messages) // 2 + result.insert(keep_count, messages[middle_index]) + + return result[:target_count] + + def _middle_out_compress_tokens( + self, messages: List[Dict[str, Any]], max_tokens: int + ) -> List[Dict[str, Any]]: + """ + Apply middle-out compression with token awareness. + + Args: + messages: Messages to compress + max_tokens: Maximum token count + + Returns: + Compressed messages + """ + # Start by keeping all messages and remove from middle until under token limit + current_messages = messages.copy() + + while ( + self._count_total_tokens(current_messages) > max_tokens + and len(current_messages) > 2 + ): + # Remove from the middle + if len(current_messages) <= 2: + break + + # Find the middle message (avoiding system messages if possible) + middle_index = len(current_messages) // 2 + + # Try to avoid removing system messages + if current_messages[middle_index].get("role") == "system": + # Look for a non-system message near the middle + for offset in range( + 1, len(current_messages) // 4 + 1 + ): + if ( + middle_index - offset >= 0 + and current_messages[ + middle_index - offset + ].get("role") + != "system" + ): + middle_index = middle_index - offset + break + if ( + middle_index + offset < len(current_messages) + and current_messages[ + middle_index + offset + ].get("role") + != "system" + ): + middle_index = middle_index + offset + break + + # Remove the middle message + current_messages.pop(middle_index) + + return current_messages + + def _count_total_tokens( + self, messages: List[Dict[str, Any]] + ) -> int: + """Count total tokens in a list of messages.""" + total_tokens = 0 + for message in messages: + content = message.get("content", "") + if isinstance(content, str): + total_tokens += count_tokens( + content, self.config.model_name + ) + elif isinstance(content, (list, dict)): + # Handle structured content + total_tokens += count_tokens( + str(content), self.config.model_name + ) + return total_tokens + + def _get_model_context_limit( + self, model_name: str + ) -> Optional[int]: + """ + Get the context token limit for a given model. 
+ + Args: + model_name: Name of the model + + Returns: + Token limit or None if unknown + """ + # Common model context limits (in tokens) + model_limits = { + "gpt-4": 8192, + "gpt-4-turbo": 128000, + "gpt-4o": 128000, + "gpt-4o-mini": 128000, + "gpt-3.5-turbo": 16385, + "claude-3-opus": 200000, + "claude-3-sonnet": 200000, + "claude-3-haiku": 200000, + "claude-3-5-sonnet": 200000, + "claude-2": 100000, + "gemini-pro": 32768, + "gemini-pro-vision": 16384, + "llama-2-7b": 4096, + "llama-2-13b": 4096, + "llama-2-70b": 4096, + } + + # Check for exact match first + if model_name in model_limits: + return model_limits[model_name] + + # Check for partial matches + for model_key, limit in model_limits.items(): + if model_key in model_name.lower(): + return limit + + # Default fallback + logger.warning( + f"Unknown model '{model_name}', using default context limit of 4096 tokens" + ) + return 4096 + + def _get_model_message_limit( + self, model_name: str + ) -> Optional[int]: + """ + Get the message count limit for a given model. + + Args: + model_name: Name of the model + + Returns: + Message limit or None if no limit + """ + # Models with known message limits + message_limits = { + "claude-3-opus": 1000, + "claude-3-sonnet": 1000, + "claude-3-haiku": 1000, + "claude-3-5-sonnet": 1000, + "claude-2": 1000, + } + + # Check for exact match first + if model_name in message_limits: + return message_limits[model_name] + + # Check for partial matches + for model_key, limit in message_limits.items(): + if model_key in model_name.lower(): + return limit + + return None # No known limit + + +def create_default_transforms( + enabled: bool = True, + method: str = "middle-out", + model_name: str = "gpt-4", +) -> MessageTransforms: + """ + Create MessageTransforms with default configuration. + + Args: + enabled: Whether transforms are enabled + method: Transform method to use + model_name: Model name for context limits + + Returns: + Configured MessageTransforms instance + """ + config = TransformConfig( + enabled=enabled, method=method, model_name=model_name + ) + return MessageTransforms(config) + + +def apply_transforms_to_messages( + messages: List[Dict[str, Any]], + transforms_config: Optional[TransformConfig] = None, + model_name: str = "gpt-4", +) -> TransformResult: + """ + Convenience function to apply transforms to messages. + + Args: + messages: List of message dictionaries + transforms_config: Optional transform configuration + model_name: Model name for context determination + + Returns: + TransformResult with processed messages + """ + if transforms_config is None: + transforms = create_default_transforms( + enabled=True, model_name=model_name + ) + else: + transforms = MessageTransforms(transforms_config) + + return transforms.transform_messages(messages, model_name) + + +def handle_transforms( + transforms: MessageTransforms, + short_memory: Conversation = None, + model_name: Optional[str] = "gpt-4o", +) -> str: + """ + Handle message transforms and return a formatted task prompt. + + Applies message transforms to the provided messages using the given + MessageTransforms instance. If compression occurs, logs the results. + Returns the formatted string of messages after transforms, or the + original message history as a string if no transforms are enabled. + + Args: + messages: List of message dictionaries to process. + transforms: MessageTransforms instance to apply. + short_memory: Object with methods to return messages as dictionary or string. + model_name: Name of the model for context. 
+ + Returns: + Formatted string of messages for the task prompt. + """ + # Get messages as dictionary format for transforms + messages_dict = short_memory.return_messages_as_dictionary() + + # Apply transforms + transform_result = transforms.transform_messages( + messages_dict, model_name + ) + + # Log transform results if compression occurred + if transform_result.was_compressed: + logger.info( + f"Applied message transforms: {transform_result.original_message_count} -> " + f"{transform_result.compressed_message_count} messages, " + f"{transform_result.original_token_count} -> {transform_result.compressed_token_count} tokens " + f"(ratio: {transform_result.compression_ratio:.2f})" + ) + + # Convert transformed messages back to string format + formatted_messages = [ + f"{message['role']}: {message['content']}" + for message in transform_result.messages + ] + task_prompt = "\n\n".join(formatted_messages) + + return task_prompt