# Agent memory manager: short-term / long-term memory with token-limit
# accounting, threshold-based archiving, search, and snapshot persistence.
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel, Field

from swarm_models.tiktoken_wrapper import TikTokenizer
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MemoryMetadata(BaseModel):
    """Metadata for memory entries.

    Tracks provenance (role, agent, session), classification, token
    accounting, and identity for a single stored message.
    """

    # default_factory gives every instance a fresh value; the previous
    # plain defaults (`= time.time()`, `= str(uuid.uuid4())`) were evaluated
    # once at class-definition time, so ALL entries created without explicit
    # values shared one timestamp and one message id.
    timestamp: Optional[float] = Field(default_factory=time.time)
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    memory_type: Optional[str] = None  # 'short_term' or 'long_term'
    token_count: Optional[int] = None
    message_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4())
    )
|
class MemoryEntry(BaseModel):
    """Single memory entry with content and metadata"""

    # Raw message text. May be None if constructed without content; callers
    # in this module (create_memory_entry) always supply it.
    content: Optional[str] = None
    # Provenance and accounting info (timestamp, role, token count, ...).
    metadata: Optional[MemoryMetadata] = None
|
class MemoryConfig(BaseModel):
    """Configuration for memory manager"""

    # Token budget for short-term memory; archiving triggers relative to this.
    max_short_term_tokens: Optional[int] = 4096
    # NOTE(review): not consulted by MemoryManager in this file — confirm use.
    max_entries: Optional[int] = None
    # NOTE(review): reserved system-message budget; also unused here — confirm.
    system_messages_token_buffer: Optional[int] = 1000
    # NOTE(review): archiving is gated on auto_archive, not this flag — confirm.
    enable_long_term_memory: Optional[bool] = False
    # When True, MemoryManager.should_archive() may trigger archival on add.
    auto_archive: Optional[bool] = True
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full
|
class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent, handling token limits,
    archival, and context retrieval.

    Short-term memory is an in-process, chronological list of ``MemoryEntry``
    objects. When its token usage reaches ``config.archive_threshold`` of
    ``config.max_short_term_tokens`` (and ``auto_archive`` is enabled), the
    oldest entries are moved into ``long_term_memory``.

    Args:
        config (MemoryConfig): Configuration for memory management
        tokenizer (Optional[Any]): Tokenizer to use for token counting; must
            expose ``count_tokens(text) -> int``. Defaults to ``TikTokenizer``.
        long_term_memory (Optional[Any]): Vector store or database for
            long-term storage; assumed to expose ``add(text)`` and
            ``query(text)`` — TODO confirm against the concrete store used.
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Working storage: chronological (append-only) lists of entries.
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Lifetime statistics (monotonically increasing).
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with metadata.

        Args:
            content (str): Message text to store.
            role (str): Role of the speaker (e.g. "user", "assistant").
            agent_name (str): Name of the agent that produced the message.
            session_id (str): Identifier of the conversation session.
            memory_type (str): "short_term", "system", or "long_term".

        Returns:
            MemoryEntry: Entry with timestamp and token count filled in.
        """
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to appropriate storage.

        System messages go to ``self.system_messages`` and are never archived;
        everything else is appended to short-term memory, which may trigger
        archival of the oldest entries.
        """
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Archive immediately so the threshold check reflects the new entry.
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += entry.metadata.token_count

    def get_current_token_count(self) -> int:
        """Get total tokens in short-term memory (system messages excluded)."""
        return sum(
            entry.metadata.token_count
            for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get total tokens in system messages."""
        return sum(
            entry.metadata.token_count
            for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check if archiving is needed based on configuration.

        Returns True when auto-archiving is enabled and short-term usage has
        reached ``archive_threshold`` of ``max_short_term_tokens``.
        (Assumes ``max_short_term_tokens`` keeps its non-None default — TODO
        confirm callers never set it to None.)
        """
        if not self.config.auto_archive:
            return False

        current_usage = (
            self.get_current_token_count()
            / self.config.max_short_term_tokens
        )
        return current_usage >= self.config.archive_threshold

    def archive_old_memories(self) -> None:
        """Move older memories to long-term storage until usage drops below
        the archive threshold.

        Stops early when short-term memory is exhausted or a store operation
        fails. The failure check matters: a failing store restores the popped
        entry to the front of the list, so without it this loop would retry
        the same entry forever (usage never decreases).
        """
        if not self.long_term_memory:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.should_archive():
            # Get oldest non-system message
            if not self.short_term_memory:
                break

            oldest_entry = self.short_term_memory.pop(0)

            # Store in long-term memory; bail out on failure (the entry has
            # already been re-inserted by store_in_long_term_memory).
            if not self.store_in_long_term_memory(oldest_entry):
                break
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> bool:
        """Store a memory entry in long-term memory.

        Returns:
            bool: True on success; False when no long-term store is configured
            or the store raised. On failure the entry is re-inserted at the
            front of short-term memory so it is not lost. (Backward
            compatible: the previous None return was ignored by all callers.)
        """
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return False

        try:
            self.long_term_memory.add(str(entry.model_dump()))
            return True
        except Exception as e:
            logger.error(f"Error storing in long-term memory: {e}")
            # Re-add to short-term if storage fails
            self.short_term_memory.insert(0, entry)
            return False

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types

        Note: system and short-term messages are included wholesale; only the
        long-term store is actually filtered by ``query``.

        Args:
            query (str): Query to match against memories
            max_tokens (Optional[int]): Maximum tokens to return

        Returns:
            str: Combined relevant context
        """
        contexts = []

        # System messages first so they survive truncation.
        # Skip None content: "\n".join would raise TypeError on it.
        for entry in self.system_messages:
            if entry.content is not None:
                contexts.append(entry.content)

        # Short-term memory, most recent first.
        for entry in reversed(self.short_term_memory):
            if entry.content is not None:
                contexts.append(entry.content)

        # Query long-term memory if available
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        # Combine and truncate if needed
        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(
                combined, max_tokens
            )

        return combined

    def truncate_to_token_limit(
        self, text: str, max_tokens: int
    ) -> str:
        """Truncate text to fit within token limit.

        Splits on ". " and keeps whole sentences while they fit. The count is
        approximate: separators removed by the split are not re-counted when
        the kept sentences are rejoined.
        """
        current_tokens = self.tokenizer.count_tokens(text)

        if current_tokens <= max_tokens:
            return text

        # Truncate by splitting into sentences and rebuilding
        sentences = text.split(". ")
        result = []
        current_count = 0

        for sentence in sentences:
            sentence_tokens = self.tokenizer.count_tokens(sentence)
            if current_count + sentence_tokens <= max_tokens:
                result.append(sentence)
                current_count += sentence_tokens
            else:
                break

        return ". ".join(result)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory with option to preserve system messages"""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()
        # Conditional is parenthesized: the original unparenthesized form
        # parsed as `("Cleared..." + suffix) if preserve_system else ""`,
        # logging an empty message when preserve_system was False.
        logger.info(
            "Cleared short-term memory"
            + (" (preserved system messages)" if preserve_system else "")
        )

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics.

        Returns a plain dict suitable for logging or JSON/YAML serialization.
        """
        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": self.get_current_token_count(),
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": self.config.max_short_term_tokens,
            "token_usage_percent": round(
                (
                    self.get_current_token_count()
                    / self.config.max_short_term_tokens
                )
                * 100,
                2,
            ),
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save current memory state to file.

        Serializes config, system messages, short-term memory, and stats as
        YAML (when the path ends in ".yaml") or JSON. Raises on I/O or
        serialization failure after logging the error.
        """
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump()
                    for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump()
                    for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w") as f:
                if file_path.endswith(".yaml"):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info(f"Saved memory snapshot to {file_path}")

        except Exception as e:
            logger.error(f"Error saving memory snapshot: {e}")
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load memory state from file.

        Replaces ``config``, ``system_messages``, and ``short_term_memory``
        with the snapshot contents. Raises on read/parse failure after
        logging the error.
        """
        try:
            with open(file_path, "r") as f:
                if file_path.endswith(".yaml"):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry)
                for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry)
                for entry in data["short_term_memory"]
            ]

            logger.info(f"Loaded memory snapshot from {file_path}")

        except Exception as e:
            logger.error(f"Error loading memory snapshot: {e}")
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of specified type

        Matching is a case-insensitive substring test against entry content;
        entries with None content never match (the original raised
        AttributeError on them).

        Args:
            query (str): Search query
            memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all")

        Returns:
            List[MemoryEntry]: Matching memory entries
        """
        results: List[MemoryEntry] = []
        # Hoisted: lowercase the query once, not once per entry.
        needle = query.lower()

        if memory_type in ["short_term", "all"]:
            results.extend(
                entry
                for entry in self.short_term_memory
                if needle in (entry.content or "").lower()
            )

        if memory_type in ["system", "all"]:
            results.extend(
                entry
                for entry in self.system_messages
                if needle in (entry.content or "").lower()
            )

        if (
            memory_type in ["long_term", "all"]
            and self.long_term_memory is not None
        ):
            # Assumes query() returns an iterable of results — TODO confirm
            # against the concrete long-term store implementation.
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # Convert long-term results to MemoryEntry format
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(
                            content
                        ),
                    )
                    results.append(
                        MemoryEntry(
                            content=content, metadata=metadata
                        )
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get short-term memories within a specific timeframe (inclusive).

        Timestamps are epoch seconds as produced by create_memory_entry.
        """
        return [
            entry
            for entry in self.short_term_memory
            if start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export memories to file in specified format.

        Args:
            file_path (str): Destination path.
            format (str): "yaml" for YAML output; anything else writes JSON.
                (Name shadows the builtin but is kept for API compatibility.)
        """
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)