pull/903/merge
harshalmore31 1 month ago committed by GitHub
commit 8e811baca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,156 @@
"""
Simple RAG Example with Swarms Framework
A concise example showing how to use the RAG integration with Swarms Agent.
This example demonstrates the core RAG functionality in a simple, easy-to-understand way.
"""
import time
from swarms.structs import Agent, RAGConfig
class SimpleMemoryStore:
    """Simple in-memory memory store for demonstration.

    Entries live in the plain list ``self.memories``. Each entry is a dict
    with ``content``, ``metadata`` and ``timestamp`` keys. Retrieval is a
    naive keyword match, not an embedding search.
    """

    def __init__(self):
        self.memories = []

    def add(self, content: str, metadata: dict = None) -> bool:
        """Add content to memory.

        Args:
            content: Text to store.
            metadata: Optional metadata; an empty dict is stored when omitted
                or falsy.

        Returns:
            Always ``True``.
        """
        self.memories.append(
            {
                "content": content,
                "metadata": metadata or {},
                "timestamp": time.time(),
            }
        )
        return True

    def query(self, query: str, top_k: int = 3, similarity_threshold: float = 0.5) -> list:
        """Simple keyword-based query.

        The score of a memory is the fraction of whitespace-separated query
        words that occur as substrings of the lower-cased content.

        Args:
            query: Free-text query.
            top_k: Maximum number of results to return.
            similarity_threshold: Minimum score a memory needs to be returned.

        Returns:
            Up to ``top_k`` dicts with ``content``, ``score`` and ``metadata``
            keys, best score first. A query with no words returns ``[]``.
        """
        # Split once up front; the word list does not change per memory.
        query_words = query.lower().split()
        if not query_words:
            # No keywords to match. Without this guard the score computation
            # below would divide by zero.
            return []

        word_count = len(query_words)
        results = []
        for memory in self.memories:
            content_lower = memory["content"].lower()
            # Simple relevance score
            matched = sum(1 for word in query_words if word in content_lower)
            relevance = min(matched / word_count, 1.0)
            if relevance >= similarity_threshold:
                results.append(
                    {
                        "content": memory["content"],
                        "score": relevance,
                        "metadata": memory["metadata"],
                    }
                )
        return sorted(results, key=lambda x: x["score"], reverse=True)[:top_k]
def main():
    """Main example demonstrating RAG functionality.

    Walks through six numbered steps: seed a SimpleMemoryStore, build a
    RAGConfig, create an Agent wired to both, run a few test queries,
    exercise the manual save/query/search helpers, and print the final RAG
    statistics. Agent calls are wrapped in try/except so that one failure
    (for example a missing API key) does not abort the rest of the demo.
    """
    print("🚀 Simple RAG Example with Swarms Framework")
    print("=" * 50)

    # Count handled errors so the closing message reflects what happened.
    error_count = 0

    # 1. Initialize memory store
    print("\n1. Setting up memory store...")
    memory_store = SimpleMemoryStore()

    # Add some knowledge to memory
    knowledge_items = [
        "Python is a versatile programming language used for web development, data science, and AI.",
        "Machine learning models learn patterns from data to make predictions.",
        "The Swarms framework enables building sophisticated multi-agent systems.",
        "RAG (Retrieval-Augmented Generation) enhances AI responses with external knowledge.",
        "Vector databases store embeddings for efficient similarity search.",
    ]
    for item in knowledge_items:
        memory_store.add(item, {"source": "knowledge_base"})
    print(f"✅ Added {len(knowledge_items)} knowledge items to memory")

    # 2. Configure RAG
    print("\n2. Configuring RAG...")
    rag_config = RAGConfig(
        similarity_threshold=0.3,  # Lower threshold for demo
        max_results=2,
        auto_save_to_memory=True,
        query_every_loop=False,  # Disable to avoid issues
        enable_conversation_summaries=True,
    )

    # 3. Create agent with RAG - using built-in model handling
    print("\n3. Creating agent with RAG...")
    agent = Agent(
        model_name="gpt-4o-mini",  # Direct model specification
        temperature=0.7,
        max_tokens=300,
        agent_name="RAG-Demo-Agent",
        long_term_memory=memory_store,
        rag_config=rag_config,
        max_loops=1,  # Reduce loops to avoid issues
        verbose=True,
    )
    print(f"✅ Agent created with RAG enabled: {agent.is_rag_enabled()}")

    # 4. Test RAG functionality
    print("\n4. Testing RAG functionality...")
    test_queries = [
        "What is Python used for?",
        "How do machine learning models work?",
        "What is the Swarms framework?",
        "Explain RAG systems",
    ]
    for i, query in enumerate(test_queries, 1):
        print(f"\n--- Query {i}: {query} ---")
        try:
            # Run the agent
            response = agent.run(query)
            print(f"🤖 Response: {response}")
            # Check RAG stats
            stats = agent.get_rag_stats()
            print(f"📊 RAG Stats: {stats.get('loops_processed', 0)} loops processed")
        except Exception as e:
            error_count += 1
            print(f"❌ Error: {e}")
        # Brief pause between queries
        time.sleep(1)

    # 5. Test manual memory operations
    print("\n5. Testing manual memory operations...")
    try:
        # Save custom content
        success = agent.save_to_rag_memory(
            "Custom knowledge: The agent successfully used RAG to enhance responses.",
            {"source": "manual_test"},
        )
        print(f"💾 Manual save: {success}")
        # Query memory directly
        result = agent.query_rag_memory("What is custom knowledge?")
        print(f"🔍 Direct query result: {result[:100]}...")
        # Search memories
        search_results = agent.search_memories("Python", top_k=2)
        print(f"🔎 Search results: {len(search_results)} items found")
    except Exception as e:
        error_count += 1
        print(f"❌ Error in manual operations: {e}")

    # 6. Final statistics
    print("\n6. Final RAG statistics...")
    try:
        final_stats = agent.get_rag_stats()
        print(f"📈 Final Stats: {final_stats}")
    except Exception as e:
        error_count += 1
        print(f"❌ Error getting stats: {e}")

    # Only claim success when no step reported an error.
    if error_count == 0:
        print("\n🎉 RAG example completed successfully!")
    else:
        print(f"\n⚠️ RAG example finished with {error_count} error(s).")
    print("=" * 50)


if __name__ == "__main__":
    main()

@ -1,5 +1,6 @@
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.structs.agent_builder import AgentsBuilder from swarms.structs.agent_builder import AgentsBuilder
from swarms.structs.agent_rag_handler import RAGConfig
from swarms.structs.auto_swarm_builder import AutoSwarmBuilder from swarms.structs.auto_swarm_builder import AutoSwarmBuilder
from swarms.structs.base_structure import BaseStructure from swarms.structs.base_structure import BaseStructure
from swarms.structs.base_swarm import BaseSwarm from swarms.structs.base_swarm import BaseSwarm
@ -156,4 +157,5 @@ __all__ = [
"find_agent_by_name", "find_agent_by_name",
"run_agent", "run_agent",
"InteractiveGroupChat", "InteractiveGroupChat",
"RAGConfig",
] ]

@ -244,6 +244,7 @@ class Agent:
artifacts_output_path (str): The artifacts output path artifacts_output_path (str): The artifacts output path
artifacts_file_extension (str): The artifacts file extension (.pdf, .md, .txt, ) artifacts_file_extension (str): The artifacts file extension (.pdf, .md, .txt, )
scheduled_run_date (datetime): The date and time to schedule the task scheduled_run_date (datetime): The date and time to schedule the task
rag_config (RAGConfig): Configuration for RAG (Retrieval-Augmented Generation) operations
Methods: Methods:
run: Run the agent run: Run the agent
@ -277,6 +278,18 @@ class Agent:
construct_dynamic_prompt: Construct the dynamic prompt construct_dynamic_prompt: Construct the dynamic prompt
handle_artifacts: Handle artifacts handle_artifacts: Handle artifacts
# RAG (Retrieval-Augmented Generation) Methods:
enable_rag: Enable RAG functionality with optional memory store and configuration
disable_rag: Disable RAG functionality
is_rag_enabled: Check if RAG functionality is enabled
get_rag_config: Get current RAG configuration
set_rag_config: Set RAG configuration
save_to_rag_memory: Manually save content to RAG memory
query_rag_memory: Manually query RAG memory
get_rag_stats: Get RAG handler statistics
search_memories: Search long-term memory using RAG handler
update_rag_config: Update RAG configuration
clear_rag_session: Clear RAG session data
Examples: Examples:
>>> from swarm_models import OpenAIChat >>> from swarm_models import OpenAIChat
@ -589,16 +602,29 @@ class Agent:
if self.random_models_on is True: if self.random_models_on is True:
self.model_name = set_random_models_for_agents() self.model_name = set_random_models_for_agents()
# Initialize RAG handler with the new comprehensive handler
if self.long_term_memory is not None: if self.long_term_memory is not None:
self.rag_handler = self.rag_setup_handling() self.rag_handler = AgentRAGHandler(
long_term_memory=self.long_term_memory,
config=self.rag_config,
agent_name=self.agent_name,
max_context_length=self.context_length,
verbose=self.verbose,
)
else:
self.rag_handler = None
def rag_setup_handling(self): def rag_setup_handling(self):
return AgentRAGHandler( """Legacy method - now handled by AgentRAGHandler initialization"""
if self.rag_handler is None and self.long_term_memory is not None:
self.rag_handler = AgentRAGHandler(
long_term_memory=self.long_term_memory, long_term_memory=self.long_term_memory,
config=self.rag_config, config=self.rag_config,
agent_name=self.agent_name, agent_name=self.agent_name,
max_context_length=self.context_length,
verbose=self.verbose, verbose=self.verbose,
) )
return self.rag_handler
def tool_handling(self): def tool_handling(self):
@ -971,9 +997,14 @@ class Agent:
# Clear the short memory # Clear the short memory
response = None response = None
# Query the long term memory first for the context # Query the long term memory first for the context using new RAG handler
if self.long_term_memory is not None: if self.rag_handler is not None:
self.memory_query(task) retrieved_context = self.rag_handler.handle_initial_memory_query(task)
if retrieved_context:
self.short_memory.add(
role="Database",
content=retrieved_context,
)
# Autosave # Autosave
if self.autosave: if self.autosave:
@ -987,6 +1018,9 @@ class Agent:
f"Task Request for {self.agent_name}", f"Task Request for {self.agent_name}",
) )
# Track tools used for final summary
tools_used = []
while ( while (
self.max_loops == "auto" self.max_loops == "auto"
or loop_count < self.max_loops or loop_count < self.max_loops
@ -1018,28 +1052,25 @@ class Agent:
# Parameters # Parameters
attempt = 0 attempt = 0
success = False success = False
has_tool_usage = False
while attempt < self.retry_attempts and not success: while attempt < self.retry_attempts and not success:
try: try:
if ( # Handle RAG operations for this loop using new handler
self.long_term_memory is not None if self.rag_handler is not None:
and self.rag_every_loop is True retrieved_context = self.rag_handler.handle_loop_memory_operations(
): task=task,
logger.info( response=response if response else "",
"Querying RAG database for context..." loop_count=loop_count,
conversation_context=task_prompt,
has_tool_usage=has_tool_usage,
) )
self.memory_query(task_prompt)
# # Generate response using LLM
# response_args = (
# (task_prompt, *args)
# if img is None
# else (task_prompt, img, *args)
# )
# # Call the LLM if retrieved_context:
# response = self.call_llm( self.short_memory.add(
# *response_args, **kwargs role="Database",
# ) content=retrieved_context,
)
response = self.call_llm( response = self.call_llm(
task=task_prompt, img=img, *args, **kwargs task=task_prompt, img=img, *args, **kwargs
@ -1066,13 +1097,14 @@ class Agent:
# Check and execute tools # Check and execute tools
if exists(self.tools): if exists(self.tools):
has_tool_usage = True
self.execute_tools( self.execute_tools(
response=response, response=response,
loop_count=loop_count, loop_count=loop_count,
) )
if exists(self.mcp_url): if exists(self.mcp_url):
has_tool_usage = True
self.mcp_tool_handling( self.mcp_tool_handling(
response, loop_count response, loop_count
) )
@ -1080,6 +1112,7 @@ class Agent:
if exists(self.mcp_url) and exists( if exists(self.mcp_url) and exists(
self.tools self.tools
): ):
has_tool_usage = True
self.mcp_tool_handling( self.mcp_tool_handling(
response, loop_count response, loop_count
) )
@ -1155,6 +1188,15 @@ class Agent:
) )
time.sleep(self.loop_interval) time.sleep(self.loop_interval)
# Handle final memory consolidation using new RAG handler
if self.rag_handler is not None:
self.rag_handler.handle_final_memory_consolidation(
task=task,
final_response=response,
total_loops=loop_count,
tools_used=tools_used,
)
if self.autosave is True: if self.autosave is True:
log_agent_data(self.to_dict()) log_agent_data(self.to_dict())
@ -1569,6 +1611,24 @@ class Agent:
f"Could not save memory manager: {e}" f"Could not save memory manager: {e}"
) )
# Save RAG handler stats if it exists
if (
hasattr(self, "rag_handler")
and self.rag_handler is not None
):
rag_stats_path = f"{os.path.splitext(base_path)[0]}_rag_stats.json"
try:
rag_stats = self.rag_handler.get_memory_stats()
with open(rag_stats_path, 'w') as f:
json.dump(rag_stats, f, indent=2)
logger.info(
f"Saved RAG handler stats to: {rag_stats_path}"
)
except Exception as e:
logger.warning(
f"Could not save RAG handler stats: {e}"
)
except Exception as e: except Exception as e:
logger.warning(f"Error saving additional components: {e}") logger.warning(f"Error saving additional components: {e}")
@ -1696,6 +1756,20 @@ class Agent:
) as executor: ) as executor:
self.executor = executor self.executor = executor
# Reinitialize RAG handler if needed
if (
hasattr(self, "long_term_memory")
and self.long_term_memory is not None
and (not hasattr(self, "rag_handler") or self.rag_handler is None)
):
self.rag_handler = AgentRAGHandler(
long_term_memory=self.long_term_memory,
config=getattr(self, "rag_config", None),
agent_name=self.agent_name,
max_context_length=self.context_length,
verbose=self.verbose,
)
# # Reinitialize tool structure if needed # # Reinitialize tool structure if needed
# if hasattr(self, 'tools') and (self.tools or getattr(self, 'list_base_models', None)): # if hasattr(self, 'tools') and (self.tools or getattr(self, 'list_base_models', None)):
# self.tool_struct = BaseTool( # self.tool_struct = BaseTool(
@ -2017,37 +2091,23 @@ class Agent:
raise error raise error
def memory_query(self, task: str = None, *args, **kwargs) -> None: def memory_query(self, task: str = None, *args, **kwargs) -> None:
try: """Legacy method - now uses AgentRAGHandler"""
# Query the long term memory if self.rag_handler is None:
if self.long_term_memory is not None: return None
formatter.print_panel(f"Querying RAG for: {task}")
memory_retrieval = self.long_term_memory.query(
task, *args, **kwargs
)
memory_retrieval = (
f"Documents Available: {str(memory_retrieval)}"
)
# # Count the tokens try:
# memory_token_count = count_tokens( # Use the new RAG handler for initial memory query
# memory_retrieval retrieved_context = self.rag_handler.handle_initial_memory_query(task)
# )
# if memory_token_count > self.memory_chunk_size:
# # Truncate the memory by the memory chunk size
# memory_retrieval = self.truncate_string_by_tokens(
# memory_retrieval, self.memory_chunk_size
# )
if retrieved_context:
self.short_memory.add( self.short_memory.add(
role="Database", role="Database",
content=memory_retrieval, content=retrieved_context,
) )
return None return None
except Exception as e: except Exception as e:
logger.error(f"An error occurred: {e}") logger.error(f"An error occurred during memory query: {e}")
raise e raise e
def sentiment_analysis_handler(self, response: str = None): def sentiment_analysis_handler(self, response: str = None):
@ -2845,3 +2905,85 @@ class Agent:
def list_output_types(self): def list_output_types(self):
return OutputType return OutputType
def get_rag_stats(self) -> Dict[str, Any]:
    """Get RAG handler statistics.

    Returns:
        The result of ``rag_handler.get_memory_stats()``, or
        ``{"rag_enabled": False}`` when no RAG handler is attached.
    """
    handler = self.rag_handler
    if handler is not None:
        return handler.get_memory_stats()
    return {"rag_enabled": False}
def search_memories(self, query: str, top_k: int = None, similarity_threshold: float = None) -> List[Dict]:
    """Search long-term memory using RAG handler.

    Args:
        query: Text to search for.
        top_k: Forwarded unchanged to the handler (``None`` by default).
        similarity_threshold: Forwarded unchanged to the handler (``None``
            by default).

    Returns:
        Whatever ``rag_handler.search_memories`` returns, or an empty list
        when no RAG handler is attached.
    """
    handler = self.rag_handler
    if handler is not None:
        return handler.search_memories(query, top_k, similarity_threshold)
    return []
def update_rag_config(self, **kwargs):
    """Update RAG configuration.

    All keyword arguments are forwarded to ``rag_handler.update_config``.
    When no RAG handler is attached a warning is logged and nothing else
    happens.
    """
    handler = self.rag_handler
    if handler is not None:
        handler.update_config(**kwargs)
        return
    logger.warning("RAG handler not initialized")
def clear_rag_session(self):
    """Clear RAG session data.

    Delegates to ``rag_handler.clear_session_data()``; does nothing when no
    RAG handler is attached.
    """
    handler = self.rag_handler
    if handler is not None:
        handler.clear_session_data()
def enable_rag(self, long_term_memory: Any = None, config: RAGConfig = None):
    """Enable RAG functionality with optional memory store and configuration.

    Args:
        long_term_memory: Memory store to use. When omitted, the agent's
            current ``long_term_memory`` is kept.
        config: RAG configuration to use. When omitted, the agent's current
            ``rag_config`` is kept.

    A new ``AgentRAGHandler`` is built from the resulting memory store and
    configuration. If the agent still has no memory store after applying the
    arguments, a warning is logged, ``rag_handler`` is left as ``None`` and
    no handler is built.
    """
    if long_term_memory is not None:
        self.long_term_memory = long_term_memory
    if config is not None:
        self.rag_config = config

    # Other initialisation paths in this class only build a handler when a
    # memory store exists; do the same here instead of reporting RAG as
    # enabled with nothing to retrieve from.
    if self.long_term_memory is None:
        self.rag_handler = None
        logger.warning(
            "Cannot enable RAG without a long-term memory store. "
            "Pass long_term_memory to enable_rag()."
        )
        return

    self.rag_handler = AgentRAGHandler(
        long_term_memory=self.long_term_memory,
        config=self.rag_config,
        agent_name=self.agent_name,
        max_context_length=self.context_length,
        verbose=self.verbose,
    )
    logger.info(f"RAG functionality enabled for agent: {self.agent_name}")
def disable_rag(self):
    """Disable RAG functionality.

    Drops the current RAG handler by setting ``rag_handler`` to ``None`` and
    logs the change. ``long_term_memory`` and ``rag_config`` are not touched.
    """
    setattr(self, "rag_handler", None)
    message = f"RAG functionality disabled for agent: {self.agent_name}"
    logger.info(message)
def is_rag_enabled(self) -> bool:
    """Check if RAG functionality is enabled.

    Returns:
        ``False`` when no RAG handler is attached; otherwise the result of
        ``rag_handler.is_enabled()``.
    """
    handler = self.rag_handler
    if handler is None:
        return False
    return handler.is_enabled()
def get_rag_config(self) -> Optional[RAGConfig]:
    """Get current RAG configuration.

    Returns:
        The ``config`` attribute of the RAG handler, or ``None`` when no RAG
        handler is attached.
    """
    handler = self.rag_handler
    return handler.config if handler is not None else None
def set_rag_config(self, config: RAGConfig):
    """Set RAG configuration.

    Args:
        config: The configuration to install on the active RAG handler.

    When no RAG handler is attached a warning is logged and nothing is
    changed. Otherwise the configuration is stored on both the handler and
    the agent.
    """
    if self.rag_handler is None:
        logger.warning("RAG handler not initialized. Use enable_rag() first.")
        return
    self.rag_handler.config = config
    # Keep the agent-level copy in sync: enable_rag() builds new handlers
    # from self.rag_config, so a stale value would silently undo this call
    # the next time the handler is rebuilt.
    self.rag_config = config
    logger.info("RAG configuration updated")
def save_to_rag_memory(self, content: str, metadata: Optional[Dict] = None, content_type: str = "manual"):
    """Manually save content to RAG memory.

    Args:
        content: Text to store.
        metadata: Optional metadata forwarded to the handler.
        content_type: Label forwarded to the handler; defaults to
            ``"manual"``.

    Returns:
        Whatever ``rag_handler.save_to_memory`` returns, or ``False`` (after
        logging a warning) when no RAG handler is attached.
    """
    handler = self.rag_handler
    if handler is not None:
        return handler.save_to_memory(content, metadata, content_type)
    logger.warning("RAG handler not initialized. Use enable_rag() first.")
    return False
def query_rag_memory(self, query: str, context_type: str = "manual") -> str:
    """Manually query RAG memory.

    Args:
        query: Text to look up.
        context_type: Label forwarded to the handler; defaults to
            ``"manual"``.

    Returns:
        Whatever ``rag_handler.query_memory`` returns, or an empty string
        (after logging a warning) when no RAG handler is attached.
    """
    handler = self.rag_handler
    if handler is not None:
        return handler.query_memory(query, context_type)
    logger.warning("RAG handler not initialized. Use enable_rag() first.")
    return ""

@ -469,10 +469,19 @@ KEY INSIGHTS:
insights = [] insights = []
sentences = response.split(".") sentences = response.split(".")
# Ensure relevance_keywords is not None
keywords = self.config.relevance_keywords or [
"important",
"key",
"critical",
"summary",
"conclusion"
]
for sentence in sentences: for sentence in sentences:
if any( if any(
keyword in sentence.lower() keyword in sentence.lower()
for keyword in self.config.relevance_keywords[:5] for keyword in keywords[:5]
): ):
insights.append(sentence.strip()) insights.append(sentence.strip())

Loading…
Cancel
Save