diff --git a/examples/misc/conversation_cache_example.py b/examples/misc/conversation_cache_example.py
new file mode 100644
index 00000000..8f6c1da5
--- /dev/null
+++ b/examples/misc/conversation_cache_example.py
@@ -0,0 +1,69 @@
+import importlib.util
+import sys
+import time
+from pathlib import Path
+from types import ModuleType
+
+# Load Conversation without importing full swarms package
+CONV_PATH = (
+    Path(__file__).resolve().parents[2]
+    / "swarms"
+    / "structs"
+    / "conversation.py"
+)
+spec = importlib.util.spec_from_file_location(
+    "swarms.structs.conversation", CONV_PATH
+)
+conversation = importlib.util.module_from_spec(spec)
+sys.modules.setdefault("swarms", ModuleType("swarms"))
+sys.modules.setdefault("swarms.structs", ModuleType("swarms.structs"))
+base_module = ModuleType("swarms.structs.base_structure")
+base_module.BaseStructure = object
+sys.modules["swarms.structs.base_structure"] = base_module
+utils_root = ModuleType("swarms.utils")
+any_to_str_mod = ModuleType("swarms.utils.any_to_str")
+formatter_mod = ModuleType("swarms.utils.formatter")
+token_mod = ModuleType("swarms.utils.litellm_tokenizer")
+any_to_str_mod.any_to_str = lambda x: str(x)
+formatter_mod.formatter = type(
+    "Formatter", (), {"print_panel": lambda *a, **k: None}
+)()
+token_mod.count_tokens = lambda s: len(str(s).split())
+sys.modules["swarms.utils"] = utils_root
+sys.modules["swarms.utils.any_to_str"] = any_to_str_mod
+sys.modules["swarms.utils.formatter"] = formatter_mod
+sys.modules["swarms.utils.litellm_tokenizer"] = token_mod
+spec.loader.exec_module(conversation)
+Conversation = conversation.Conversation
+
+# Demonstrate cached conversation history
+conv = Conversation()
+conv.add("user", "Hello")
+conv.add("assistant", "Hi there!")
+
+print(conv.get_str())
+print(conv.get_str())  # reuses cached string
+
+# Timing demo
+start = time.perf_counter()
+for _ in range(1000):
+    conv.get_str()
+cached_time = time.perf_counter() - start
+print("Cached retrieval:", round(cached_time, 6), "seconds")
+
+
+# Compare to rebuilding manually
+def slow_get():
+    formatted = [
+        f"{m['role']}: {m['content']}"
+        for m in conv.conversation_history
+    ]
+    return "\n\n".join(formatted)
+
+
+start = time.perf_counter()
+for _ in range(1000):
+    slow_get()
+slow_time = time.perf_counter() - start
+print("Manual join:", round(slow_time, 6), "seconds")
+print("Speedup:", round(slow_time / cached_time, 2), "x")
diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index 4064620b..a3588e87 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -994,9 +994,7 @@ class Agent:
             self.dynamic_temperature()
 
         # Task prompt
-        task_prompt = (
-            self.short_memory.return_history_as_string()
-        )
+        task_prompt = self.short_memory.get_str()
 
         # Parameters
         attempt = 0
@@ -2119,9 +2117,7 @@ class Agent:
     def check_available_tokens(self):
         # Log the amount of tokens left in the memory and in the task
         if self.tokenizer is not None:
-            tokens_used = count_tokens(
-                self.short_memory.return_history_as_string()
-            )
+            tokens_used = count_tokens(self.short_memory.get_str())
             logger.info(
                 f"Tokens available: {self.context_length - tokens_used}"
             )
@@ -2130,9 +2126,7 @@ class Agent:
 
     def tokens_checks(self):
         # Check the tokens available
-        tokens_used = count_tokens(
-            self.short_memory.return_history_as_string()
-        )
+        tokens_used = count_tokens(self.short_memory.get_str())
         out = self.check_available_tokens()
 
         logger.info(
diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py
index a87b1579..6d5706e9 100644
--- a/swarms/structs/conversation.py
+++ b/swarms/structs/conversation.py
@@ -137,6 +137,10 @@ class Conversation(BaseStructure):
         self.token_count = token_count
         self.provider = provider
 
+        # Cache for history string to avoid repeated joins
+        self._history_cache = ""
+        self._cache_index = 0
+
         # Create conversation directory if saving is enabled
         if self.save_enabled and self.conversations_dir:
             os.makedirs(self.conversations_dir, exist_ok=True)
@@ -144,6 +148,14 @@ class Conversation(BaseStructure):
         # Try to load existing conversation or initialize new one
         self.setup()
 
+    def _refresh_cache(self):
+        """Refresh cached history string."""
+        self._history_cache = "\n\n".join(
+            f"{m['role']}: {m['content']}"
+            for m in self.conversation_history
+        )
+        self._cache_index = len(self.conversation_history)
+
     def setup(self):
         """Set up the conversation by either loading existing data or initializing new."""
         if self.load_filepath and os.path.exists(self.load_filepath):
@@ -239,6 +251,14 @@ class Conversation(BaseStructure):
         # Add message to conversation history
         self.conversation_history.append(message)
 
+        # Incrementally update history cache
+        formatted = f"{role}: {content}"
+        if self._history_cache:
+            self._history_cache += "\n\n" + formatted
+        else:
+            self._history_cache = formatted
+        self._cache_index = len(self.conversation_history)
+
         if self.token_count is True:
             self._count_tokens(content, message)
 
@@ -452,13 +472,14 @@ class Conversation(BaseStructure):
         Returns:
             str: The conversation history formatted as a string.
         """
-        formatted_messages = []
-        for message in self.conversation_history:
-            formatted_messages.append(
-                f"{message['role']}: {message['content']}"
+        # If new messages were added outside the add method, rebuild cache
+        if self._cache_index < len(self.conversation_history):
+            self._history_cache = "\n\n".join(
+                f"{m['role']}: {m['content']}"
+                for m in self.conversation_history
             )
-
-        return "\n\n".join(formatted_messages)
+            self._cache_index = len(self.conversation_history)
+        return self._history_cache
 
     def get_str(self) -> str:
         """Get the conversation history as a string.
@@ -542,6 +563,9 @@ class Conversation(BaseStructure):
             # Load conversation history
             self.conversation_history = data.get("history", [])
 
+            # Rebuild cache from loaded history
+            self._refresh_cache()
+
             logger.info(
                 f"Successfully loaded conversation from {filename}"
             )
@@ -604,6 +628,8 @@ class Conversation(BaseStructure):
     def clear(self):
         """Clear the conversation history."""
         self.conversation_history = []
+        self._history_cache = ""
+        self._cache_index = 0
 
     def to_json(self):
         """Convert the conversation history to a JSON string.
diff --git a/tests/structs/test_conversation_cache_perf.py b/tests/structs/test_conversation_cache_perf.py
new file mode 100644
index 00000000..148723fd
--- /dev/null
+++ b/tests/structs/test_conversation_cache_perf.py
@@ -0,0 +1,65 @@
+import importlib.util
+import sys
+import time
+from pathlib import Path
+from types import ModuleType
+
+# Load Conversation without importing the full swarms package
+CONV_PATH = (
+    Path(__file__).resolve().parents[2]
+    / "swarms"
+    / "structs"
+    / "conversation.py"
+)
+spec = importlib.util.spec_from_file_location(
+    "swarms.structs.conversation", CONV_PATH
+)
+conversation = importlib.util.module_from_spec(spec)
+sys.modules.setdefault("swarms", ModuleType("swarms"))
+sys.modules.setdefault("swarms.structs", ModuleType("swarms.structs"))
+base_module = ModuleType("swarms.structs.base_structure")
+base_module.BaseStructure = object
+sys.modules["swarms.structs.base_structure"] = base_module
+utils_root = ModuleType("swarms.utils")
+any_to_str_mod = ModuleType("swarms.utils.any_to_str")
+formatter_mod = ModuleType("swarms.utils.formatter")
+token_mod = ModuleType("swarms.utils.litellm_tokenizer")
+any_to_str_mod.any_to_str = lambda x: str(x)
+formatter_mod.formatter = type(
+    "Formatter", (), {"print_panel": lambda *a, **k: None}
+)()
+token_mod.count_tokens = lambda s: len(str(s).split())
+sys.modules["swarms.utils"] = utils_root
+sys.modules["swarms.utils.any_to_str"] = any_to_str_mod
+sys.modules["swarms.utils.formatter"] = formatter_mod
+sys.modules["swarms.utils.litellm_tokenizer"] = token_mod
+spec.loader.exec_module(conversation)
+Conversation = conversation.Conversation
+
+
+class OldConversation(Conversation):
+    def return_history_as_string(self):
+        formatted = [
+            f"{m['role']}: {m['content']}"
+            for m in self.conversation_history
+        ]
+        return "\n\n".join(formatted)
+
+    def get_str(self):
+        return self.return_history_as_string()
+
+
+def measure(conv_cls, messages=50, loops=1000):
+    conv = conv_cls(token_count=False)
+    for i in range(messages):
+        conv.add("user", f"msg{i}")
+    start = time.perf_counter()
+    for _ in range(loops):
+        conv.get_str()
+    return time.perf_counter() - start
+
+
+def test_cache_perf_improvement():
+    old_time = measure(OldConversation)
+    new_time = measure(Conversation)
+    assert old_time / new_time >= 2
diff --git a/tests/structs/test_conversation_cache_simple.py b/tests/structs/test_conversation_cache_simple.py
new file mode 100644
index 00000000..b59fed4a
--- /dev/null
+++ b/tests/structs/test_conversation_cache_simple.py
@@ -0,0 +1,51 @@
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+# Load Conversation without importing swarms package to avoid optional deps
+CONV_PATH = (
+    Path(__file__).resolve().parents[2]
+    / "swarms"
+    / "structs"
+    / "conversation.py"
+)
+spec = importlib.util.spec_from_file_location(
+    "swarms.structs.conversation", CONV_PATH
+)
+conversation = importlib.util.module_from_spec(spec)
+sys.modules.setdefault("swarms", ModuleType("swarms"))
+structs_module = ModuleType("swarms.structs")
+sys.modules.setdefault("swarms.structs", structs_module)
+# Minimal BaseStructure to satisfy Conversation import
+base_module = ModuleType("swarms.structs.base_structure")
+base_module.BaseStructure = object
+sys.modules["swarms.structs.base_structure"] = base_module
+utils_root = ModuleType("swarms.utils")
+any_to_str_mod = ModuleType("swarms.utils.any_to_str")
+formatter_mod = ModuleType("swarms.utils.formatter")
+token_mod = ModuleType("swarms.utils.litellm_tokenizer")
+any_to_str_mod.any_to_str = lambda x: str(x)
+formatter_mod.formatter = type(
+    "Formatter", (), {"print_panel": lambda *a, **k: None}
+)()
+token_mod.count_tokens = lambda s: len(str(s).split())
+sys.modules["swarms.utils"] = utils_root
+sys.modules["swarms.utils.any_to_str"] = any_to_str_mod
+sys.modules["swarms.utils.formatter"] = formatter_mod
+sys.modules["swarms.utils.litellm_tokenizer"] = token_mod
+spec.loader.exec_module(conversation)
+Conversation = conversation.Conversation
+
+
+def test_history_cache_updates_incrementally():
+    conv = Conversation(token_count=False)
+    conv.add("user", "Hello")
+    first_cache = conv.get_str()
+    assert first_cache == "user: Hello"
+    conv.add("assistant", "Hi")
+    second_cache = conv.get_str()
+    assert second_cache.endswith("assistant: Hi")
+    assert conv._cache_index == len(conv.conversation_history)
+    # Ensure cache reused when no new messages
+    assert conv.get_str() is second_cache