You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
522 lines
16 KiB
522 lines
16 KiB
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from loguru import logger
|
|
|
|
from swarms.structs.conversation import Conversation
|
|
from swarms.utils.litellm_tokenizer import count_tokens
|
|
|
|
|
|
@dataclass
class TransformConfig:
    """
    Settings that control how a message history is compressed.

    Attributes:
        enabled: Master switch; defaults to False.
        method: Name of the compression strategy; defaults to "middle-out".
        max_tokens: Explicit token budget, or None when no explicit
            budget is configured.
        max_messages: Explicit cap on the number of messages, or None
            when no explicit cap is configured.
        model_name: Model identifier associated with this configuration;
            defaults to "gpt-4".
        preserve_system_messages: Whether system messages should be kept
            when the history is compressed; defaults to True.
        preserve_recent_messages: How many of the most recent messages
            should be kept when the history is compressed; defaults to 2.
    """

    # Feature toggle and strategy selection.
    enabled: bool = False
    method: str = "middle-out"

    # Optional explicit limits (None means "not configured").
    max_tokens: Optional[int] = None
    max_messages: Optional[int] = None

    # Model the limits and token counts refer to.
    model_name: str = "gpt-4"

    # What to protect from removal.
    preserve_system_messages: bool = True
    preserve_recent_messages: int = 2
|
|
|
|
|
|
@dataclass
class TransformResult:
    """
    Outcome of a message transformation, with before/after statistics.

    Attributes:
        messages: The message dictionaries after transformation.
        original_token_count: Token count before transformation.
        compressed_token_count: Token count after transformation.
        original_message_count: Number of messages before transformation.
        compressed_message_count: Number of messages after transformation.
        compression_ratio: Ratio describing how much the history shrank.
        was_compressed: Whether the transformation reduced the history.
    """

    # Transformed payload.
    messages: List[Dict[str, Any]]

    # Token statistics.
    original_token_count: int
    compressed_token_count: int

    # Message-count statistics.
    original_message_count: int
    compressed_message_count: int

    # Summary indicators.
    compression_ratio: float
    was_compressed: bool
|
|
|
|
|
|
class MessageTransforms:
    """
    Handles message transformations for context size management.

    Supports middle-out compression, which removes messages from the
    middle of the conversation while preserving the beginning and end,
    which are typically more important for context.
    """

    # Known context-window sizes (in tokens), keyed by model identifier.
    _MODEL_CONTEXT_LIMITS: Dict[str, int] = {
        "gpt-4": 8192,
        "gpt-4-turbo": 128000,
        "gpt-4o": 128000,
        "gpt-4o-mini": 128000,
        "gpt-3.5-turbo": 16385,
        "claude-3-opus": 200000,
        "claude-3-sonnet": 200000,
        "claude-3-haiku": 200000,
        "claude-3-5-sonnet": 200000,
        "claude-2": 100000,
        "gemini-pro": 32768,
        "gemini-pro-vision": 16384,
        "llama-2-7b": 4096,
        "llama-2-13b": 4096,
        "llama-2-70b": 4096,
    }

    # Known per-request message-count limits, keyed by model identifier.
    _MODEL_MESSAGE_LIMITS: Dict[str, int] = {
        "claude-3-opus": 1000,
        "claude-3-sonnet": 1000,
        "claude-3-haiku": 1000,
        "claude-3-5-sonnet": 1000,
        "claude-2": 1000,
    }

    def __init__(self, config: "TransformConfig"):
        """
        Initialize the MessageTransforms with configuration.

        Args:
            config: TransformConfig object with transformation settings
        """
        self.config = config

    def transform_messages(
        self,
        messages: List[Dict[str, Any]],
        target_model: Optional[str] = None,
    ) -> "TransformResult":
        """
        Transform messages according to the configured strategy.

        The message-count limit is applied first, then the token limit.
        Limits come from ``config.max_messages`` / ``config.max_tokens``
        when set, otherwise from the built-in per-model tables. A limit
        of 0 (or None) is treated as "no limit". Token counts are always
        computed with ``config.model_name``; ``target_model`` only
        selects the limits.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            target_model: Optional target model name to determine context
                limits; falls back to ``config.model_name``

        Returns:
            TransformResult containing transformed messages and metadata
        """
        if not self.config.enabled or not messages:
            # Nothing to do: report the input unchanged (count tokens once).
            token_count = self._count_total_tokens(messages)
            return TransformResult(
                messages=messages,
                original_token_count=token_count,
                compressed_token_count=token_count,
                original_message_count=len(messages),
                compressed_message_count=len(messages),
                compression_ratio=1.0,
                was_compressed=False,
            )

        # Use target model if provided, otherwise use config model
        model_name = target_model or self.config.model_name

        # Explicit config values win; only consult the model tables when a
        # limit is not configured, so an overridden limit never triggers
        # the "unknown model" warning.
        max_tokens = self.config.max_tokens
        if max_tokens is None:
            max_tokens = self._get_model_context_limit(model_name)
        max_messages = self.config.max_messages
        if max_messages is None:
            max_messages = self._get_model_message_limit(model_name)

        original_tokens = self._count_total_tokens(messages)
        original_messages = len(messages)

        transformed_messages = messages.copy()
        current_tokens = original_tokens

        # Apply transformations, re-counting tokens only after a change.
        if max_messages and len(transformed_messages) > max_messages:
            transformed_messages = self._compress_message_count(
                transformed_messages, max_messages
            )
            current_tokens = self._count_total_tokens(transformed_messages)

        if max_tokens and current_tokens > max_tokens:
            transformed_messages = self._compress_tokens(
                transformed_messages, max_tokens
            )
            current_tokens = self._count_total_tokens(transformed_messages)

        compressed_tokens = current_tokens
        compressed_messages = len(transformed_messages)

        compression_ratio = (
            compressed_tokens / original_tokens
            if original_tokens > 0
            else 1.0
        )

        return TransformResult(
            messages=transformed_messages,
            original_token_count=original_tokens,
            compressed_token_count=compressed_tokens,
            original_message_count=original_messages,
            compressed_message_count=compressed_messages,
            compression_ratio=compression_ratio,
            was_compressed=compressed_tokens < original_tokens
            or compressed_messages < original_messages,
        )

    def _compress_message_count(
        self, messages: List[Dict[str, Any]], max_messages: int
    ) -> List[Dict[str, Any]]:
        """
        Compress message count using middle-out strategy.

        When ``config.preserve_system_messages`` is true, system messages
        are always kept and are placed at the front of the result. The
        last ``config.preserve_recent_messages`` non-system messages are
        kept next (clamped to the free slots, newest first), and any
        remaining slots are filled from the older messages via
        middle-out selection.

        Args:
            messages: List of messages to compress
            max_messages: Maximum number of messages to keep

        Returns:
            Compressed list of messages (at most ``max_messages`` long)
        """
        if len(messages) <= max_messages:
            return messages

        if self.config.preserve_system_messages:
            system_messages = [
                msg for msg in messages if msg.get("role") == "system"
            ]
            other_messages = [
                msg for msg in messages if msg.get("role") != "system"
            ]
        else:
            # System messages get no special treatment.
            system_messages = []
            other_messages = list(messages)

        # Calculate how many non-system messages we can keep
        available_slots = max_messages - len(system_messages)
        if available_slots <= 0:
            # If we can't fit any non-system messages, just return system messages
            return system_messages[: max(max_messages, 0)]

        # Clamp the recent window to the free slots so that, when space is
        # tight, the newest messages survive rather than being sliced off.
        preserve_recent = min(
            self.config.preserve_recent_messages,
            len(other_messages),
            available_slots,
        )
        if preserve_recent > 0:
            recent_messages = other_messages[-preserve_recent:]
            early_messages = other_messages[:-preserve_recent]
        else:
            recent_messages = []
            early_messages = other_messages

        remaining_slots = available_slots - len(recent_messages)
        if remaining_slots <= 0:
            # Only keep system messages and recent messages
            return system_messages + recent_messages

        # early_messages is always longer than remaining_slots here (the
        # input exceeded max_messages), so middle-out selection applies.
        compressed_early = self._middle_out_compress(
            early_messages, remaining_slots
        )
        return system_messages + compressed_early + recent_messages

    def _compress_tokens(
        self, messages: List[Dict[str, Any]], max_tokens: int
    ) -> List[Dict[str, Any]]:
        """
        Compress messages to fit within token limit using middle-out strategy.

        Args:
            messages: List of messages to compress
            max_tokens: Maximum token count

        Returns:
            Compressed list of messages
        """
        if self._count_total_tokens(messages) <= max_tokens:
            return messages

        # For long histories, halve the message count first; 50 is an
        # arbitrary threshold for when that coarse step is worthwhile.
        if len(messages) > 50:
            messages = self._compress_message_count(
                messages, len(messages) // 2
            )
            if self._count_total_tokens(messages) <= max_tokens:
                return messages

        # Apply middle-out compression with token awareness
        return self._middle_out_compress_tokens(messages, max_tokens)

    def _middle_out_compress(
        self, messages: List[Dict[str, Any]], target_count: int
    ) -> List[Dict[str, Any]]:
        """
        Apply middle-out compression to reduce message count.

        Keeps the first and last ``target_count // 2`` messages; when
        ``target_count`` is odd, the message at the middle index of the
        input is kept between them.

        Args:
            messages: Messages to compress
            target_count: Target number of messages

        Returns:
            Compressed messages
        """
        if len(messages) <= target_count:
            return messages
        if target_count <= 0:
            return []

        keep_count = target_count // 2
        first_half = messages[:keep_count]
        # Explicit start index: messages[-0:] would return the whole list.
        last_half = messages[len(messages) - keep_count:]

        if target_count % 2 == 1:
            middle_message = messages[len(messages) // 2]
            return first_half + [middle_message] + last_half
        return first_half + last_half

    def _middle_out_compress_tokens(
        self, messages: List[Dict[str, Any]], max_tokens: int
    ) -> List[Dict[str, Any]]:
        """
        Apply middle-out compression with token awareness.

        Repeatedly removes the message nearest the middle until the total
        token count fits or only two messages remain. When
        ``config.preserve_system_messages`` is true, a nearby non-system
        message is removed in preference to a system message.

        Args:
            messages: Messages to compress
            max_tokens: Maximum token count

        Returns:
            Compressed messages (may still exceed ``max_tokens`` if two
            messages alone are over budget)
        """
        current_messages = list(messages)

        # Tokenize each message once and keep a running total instead of
        # re-counting the whole list on every removal.
        token_counts = [
            self._count_message_tokens(message)
            for message in current_messages
        ]
        total_tokens = sum(token_counts)

        while total_tokens > max_tokens and len(current_messages) > 2:
            removal_index = self._pick_removal_index(current_messages)
            current_messages.pop(removal_index)
            total_tokens -= token_counts.pop(removal_index)

        return current_messages

    def _pick_removal_index(self, messages: List[Dict[str, Any]]) -> int:
        """Return the index of the message to drop next (middle-first)."""
        middle_index = len(messages) // 2

        if not self.config.preserve_system_messages:
            return middle_index
        if messages[middle_index].get("role") != "system":
            return middle_index

        # Look for a non-system message near the middle, checking the
        # earlier neighbour before the later one at each distance.
        for offset in range(1, len(messages) // 4 + 1):
            for candidate in (middle_index - offset, middle_index + offset):
                if (
                    0 <= candidate < len(messages)
                    and messages[candidate].get("role") != "system"
                ):
                    return candidate

        # No non-system message close enough: fall back to the middle.
        return middle_index

    def _count_message_tokens(self, message: Dict[str, Any]) -> int:
        """Count tokens in one message's content using ``config.model_name``."""
        content = message.get("content", "")
        if isinstance(content, str):
            return count_tokens(content, self.config.model_name)
        if isinstance(content, (list, dict)):
            # Handle structured content
            return count_tokens(str(content), self.config.model_name)
        # Any other content type contributes nothing.
        return 0

    def _count_total_tokens(
        self, messages: List[Dict[str, Any]]
    ) -> int:
        """Count total tokens in a list of messages."""
        total_tokens = 0
        for message in messages:
            total_tokens += self._count_message_tokens(message)
        return total_tokens

    @staticmethod
    def _lookup_model_limit(
        model_name: str, limits: Dict[str, int]
    ) -> Optional[int]:
        """
        Resolve a model name against a limits table.

        An exact match wins; otherwise the longest key contained in the
        lower-cased model name is used, so "openai/gpt-4o-mini" resolves
        to "gpt-4o-mini" rather than the shorter "gpt-4".
        """
        if model_name in limits:
            return limits[model_name]

        normalized = model_name.lower()
        candidates = [key for key in limits if key in normalized]
        if not candidates:
            return None
        return limits[max(candidates, key=len)]

    def _get_model_context_limit(
        self, model_name: str
    ) -> Optional[int]:
        """
        Get the context token limit for a given model.

        Args:
            model_name: Name of the model

        Returns:
            Token limit from the built-in table, or 4096 (with a logged
            warning) when the model is unknown
        """
        limit = self._lookup_model_limit(
            model_name, self._MODEL_CONTEXT_LIMITS
        )
        if limit is not None:
            return limit

        # Default fallback
        logger.warning(
            f"Unknown model '{model_name}', using default context limit of 4096 tokens"
        )
        return 4096

    def _get_model_message_limit(
        self, model_name: str
    ) -> Optional[int]:
        """
        Get the message count limit for a given model.

        Args:
            model_name: Name of the model

        Returns:
            Message limit or None if no limit
        """
        return self._lookup_model_limit(
            model_name, self._MODEL_MESSAGE_LIMITS
        )
|
|
|
|
|
|
def create_default_transforms(
    enabled: bool = True,
    method: str = "middle-out",
    model_name: str = "gpt-4",
) -> MessageTransforms:
    """
    Build a MessageTransforms instance from a minimal set of options.

    A TransformConfig is created from the three arguments; every other
    configuration field keeps its dataclass default.

    Args:
        enabled: Whether transforms are enabled
        method: Transform method to use
        model_name: Model name stored on the configuration

    Returns:
        MessageTransforms wrapping the newly created configuration
    """
    return MessageTransforms(
        TransformConfig(
            enabled=enabled,
            method=method,
            model_name=model_name,
        )
    )
|
|
|
|
|
|
def apply_transforms_to_messages(
    messages: List[Dict[str, Any]],
    transforms_config: Optional[TransformConfig] = None,
    model_name: str = "gpt-4",
) -> TransformResult:
    """
    One-shot helper: build a transformer and run it over ``messages``.

    When ``transforms_config`` is None, an enabled default transformer is
    created for ``model_name``; otherwise the supplied configuration is
    used. In both cases ``model_name`` is forwarded to
    ``transform_messages`` as its second argument.

    Args:
        messages: List of message dictionaries
        transforms_config: Optional transform configuration
        model_name: Model name for context determination

    Returns:
        TransformResult with processed messages
    """
    transformer = (
        MessageTransforms(transforms_config)
        if transforms_config is not None
        else create_default_transforms(
            enabled=True, model_name=model_name
        )
    )
    return transformer.transform_messages(messages, model_name)
|
|
|
|
|
|
def handle_transforms(
    transforms: "MessageTransforms",
    short_memory: Optional["Conversation"] = None,
    model_name: Optional[str] = "gpt-4o",
) -> str:
    """
    Handle message transforms and return a formatted task prompt.

    Reads the message history from ``short_memory``, runs it through
    ``transforms.transform_messages`` and, if compression occurred, logs
    the before/after statistics. The resulting messages are rendered as
    ``"<role>: <content>"`` entries separated by blank lines.

    Args:
        transforms: MessageTransforms instance to apply.
        short_memory: Conversation providing
            ``return_messages_as_dictionary()``. Required despite the
            ``None`` default, which is kept for signature compatibility.
        model_name: Name of the model, passed to ``transform_messages``
            as the target model.

    Returns:
        Formatted string of messages for the task prompt.

    Raises:
        ValueError: If ``short_memory`` is None.
        KeyError: If a resulting message lacks a 'role' or 'content' key.
    """
    if short_memory is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            "handle_transforms requires a short_memory conversation, got None"
        )

    # Get messages as dictionary format for transforms
    messages_dict = short_memory.return_messages_as_dictionary()

    # Apply transforms
    transform_result = transforms.transform_messages(
        messages_dict, model_name
    )

    # Log transform results if compression occurred
    if transform_result.was_compressed:
        logger.info(
            f"Applied message transforms: {transform_result.original_message_count} -> "
            f"{transform_result.compressed_message_count} messages, "
            f"{transform_result.original_token_count} -> {transform_result.compressed_token_count} tokens "
            f"(ratio: {transform_result.compression_ratio:.2f})"
        )

    # Convert transformed messages back to string format
    return "\n\n".join(
        f"{message['role']}: {message['content']}"
        for message in transform_result.messages
    )
|