vllm docs and cleanup

pull/811/head
Kye Gomez 2 weeks ago
parent 50d646d069
commit a8b51f3150

@ -0,0 +1,25 @@
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.tools.mcp_integration import MCPServerSseParams
server_one = MCPServerSseParams(
url="http://127.0.0.1:6274",
headers={"Content-Type": "application/json"},
)
# Initialize the agent
agent = Agent(
agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
max_loops=1,
mcp_servers=[server_one],
output_type="final",
)
out = agent.run("Use the add tool to add 2 and 2")
print(type(out))

@ -7,6 +7,10 @@ black . && echo "✅ Code formatting complete!" || echo "❌ Black formatting fa
echo "🔍 Running Ruff linter..." echo "🔍 Running Ruff linter..."
ruff check . --fix && echo "✅ Linting complete!" || echo "❌ Linting failed!" ruff check . --fix && echo "✅ Linting complete!" || echo "❌ Linting failed!"
echo "Creating llm.txt file..."
python3 llm_txt.py && echo "✅ llm.txt file created!" || echo "❌ llm.txt file creation failed!"
echo "🏗️ Building package..." echo "🏗️ Building package..."
poetry build && echo "✅ Build successful!" || echo "❌ Build failed!" poetry build && echo "✅ Build successful!" || echo "❌ Build failed!"

File diff suppressed because it is too large Load Diff

@ -283,6 +283,7 @@ nav:
- Ollama: "swarms/examples/ollama.md" - Ollama: "swarms/examples/ollama.md"
- OpenRouter: "swarms/examples/openrouter.md" - OpenRouter: "swarms/examples/openrouter.md"
- XAI: "swarms/examples/xai.md" - XAI: "swarms/examples/xai.md"
- VLLM: "swarms/examples/vllm_integration.md"
- Swarms Tools: - Swarms Tools:
- Agent with Yahoo Finance: "swarms/examples/yahoo_finance.md" - Agent with Yahoo Finance: "swarms/examples/yahoo_finance.md"
- Twitter Agents: "swarms_tools/twitter.md" - Twitter Agents: "swarms_tools/twitter.md"
@ -299,6 +300,7 @@ nav:
- Group Chat Example: "swarms/examples/groupchat_example.md" - Group Chat Example: "swarms/examples/groupchat_example.md"
- Meme Agent Builder: "swarms/examples/meme_agents.md" - Meme Agent Builder: "swarms/examples/meme_agents.md"
- Sequential Workflow Example: "swarms/examples/sequential_example.md" - Sequential Workflow Example: "swarms/examples/sequential_example.md"
- ConcurrentWorkflow with VLLM Agents: "swarms/examples/vllm.md"
- External Agents: - External Agents:
- Swarms of Browser Agents: "swarms/examples/swarms_of_browser_agents.md" - Swarms of Browser Agents: "swarms/examples/swarms_of_browser_agents.md"
- Swarms UI: - Swarms UI:

@ -0,0 +1,429 @@
# VLLM Swarm Agents
!!! tip "Quick Summary"
This guide demonstrates how to create a sophisticated multi-agent system using VLLM and Swarms for comprehensive stock market analysis. You'll learn how to configure and orchestrate multiple AI agents working together to provide deep market insights.
## Overview
The example showcases how to build a stock analysis system with 5 specialized agents:
- Technical Analysis Agent
- Fundamental Analysis Agent
- Market Sentiment Agent
- Quantitative Strategy Agent
- Portfolio Strategy Agent
Each agent has specific expertise and works collaboratively through a concurrent workflow.
## Prerequisites
!!! warning "Requirements"
Before starting, ensure you have:
- Python 3.7 or higher
- The Swarms package installed
- Access to VLLM compatible models
- Sufficient compute resources for running VLLM
## Installation
!!! example "Setup Steps"
1. Install the Swarms package:
```bash
pip install swarms
```
2. Install VLLM dependencies (if not already installed):
```bash
pip install vllm
```
## Basic Usage
Here's a complete example of setting up the stock analysis swarm:
```python
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize the VLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
)
```
!!! note "Model Selection"
The example uses Llama-2-7b-chat, but you can use any VLLM-compatible model. Make sure you have the necessary permissions and resources to run your chosen model.
## Agent Configuration
### Technical Analysis Agent
```python
technical_analyst = Agent(
agent_name="Technical-Analysis-Agent",
agent_description="Expert in technical analysis and chart patterns",
system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
max_loops=1,
llm=vllm,
)
```
!!! tip "Agent Customization"
Each agent can be customized with different:
- System prompts
- Temperature settings
- Max token limits
- Response formats
## Running the Swarm
To execute the swarm analysis:
```python
swarm = ConcurrentWorkflow(
name="Stock-Analysis-Swarm",
description="A swarm of agents that analyze stocks and provide comprehensive analysis.",
agents=stock_analysis_agents,
)
# Run the analysis
response = swarm.run("Analyze the best etfs for gold and other similar commodities in volatile markets")
```
## Full Code Example
```python
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize the VLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
)
# Technical Analysis Agent
technical_analyst = Agent(
agent_name="Technical-Analysis-Agent",
agent_description="Expert in technical analysis and chart patterns",
system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
max_loops=1,
llm=vllm,
)
# Fundamental Analysis Agent
fundamental_analyst = Agent(
agent_name="Fundamental-Analysis-Agent",
agent_description="Expert in company fundamentals and valuation",
system_prompt="""You are an expert Fundamental Analysis Agent specializing in company valuation and financial metrics. Your core responsibilities include:
1. FINANCIAL STATEMENT ANALYSIS
- Analyze income statements, balance sheets, and cash flow statements
- Calculate and interpret key financial ratios
- Evaluate revenue growth and profit margins
- Assess company's debt levels and cash position
2. VALUATION METRICS
- Calculate fair value using multiple valuation methods:
* Discounted Cash Flow (DCF)
* Price-to-Earnings (P/E)
* Price-to-Book (P/B)
* Enterprise Value/EBITDA
- Compare valuations against industry peers
3. BUSINESS MODEL ASSESSMENT
- Evaluate competitive advantages and market position
- Analyze industry dynamics and market share
- Assess management quality and corporate governance
- Identify potential risks and growth opportunities
4. ECONOMIC CONTEXT
- Consider macroeconomic factors affecting the company
- Analyze industry cycles and trends
- Evaluate regulatory environment and compliance
- Assess global market conditions
Your analysis should be comprehensive, focusing on both quantitative metrics and qualitative factors that impact long-term value.""",
max_loops=1,
llm=vllm,
)
# Market Sentiment Agent
sentiment_analyst = Agent(
agent_name="Market-Sentiment-Agent",
agent_description="Expert in market psychology and sentiment analysis",
system_prompt="""You are an expert Market Sentiment Agent specializing in analyzing market psychology and investor behavior. Your key responsibilities include:
1. SENTIMENT INDICATORS
- Monitor and interpret market sentiment indicators:
* VIX (Fear Index)
* Put/Call Ratio
* Market Breadth
* Investor Surveys
- Track institutional vs retail investor behavior
2. NEWS AND SOCIAL MEDIA ANALYSIS
- Analyze news flow and media sentiment
- Monitor social media trends and discussions
- Track analyst recommendations and changes
- Evaluate corporate insider trading patterns
3. MARKET POSITIONING
- Assess hedge fund positioning and exposure
- Monitor short interest and short squeeze potential
- Track fund flows and asset allocation trends
- Analyze options market sentiment
4. CONTRARIAN SIGNALS
- Identify extreme sentiment readings
- Detect potential market turning points
- Analyze historical sentiment patterns
- Provide contrarian trading opportunities
Your analysis should combine quantitative sentiment metrics with qualitative assessment of market psychology and crowd behavior.""",
max_loops=1,
llm=vllm,
)
# Quantitative Strategy Agent
quant_analyst = Agent(
agent_name="Quantitative-Strategy-Agent",
agent_description="Expert in quantitative analysis and algorithmic strategies",
system_prompt="""You are an expert Quantitative Strategy Agent specializing in data-driven investment strategies. Your primary responsibilities include:
1. FACTOR ANALYSIS
- Analyze and monitor factor performance:
* Value
* Momentum
* Quality
* Size
* Low Volatility
- Calculate factor exposures and correlations
2. STATISTICAL ANALYSIS
- Perform statistical arbitrage analysis
- Calculate and monitor pair trading opportunities
- Analyze market anomalies and inefficiencies
- Develop mean reversion strategies
3. RISK MODELING
- Build and maintain risk models
- Calculate portfolio optimization metrics
- Monitor correlation matrices
- Analyze tail risk and stress scenarios
4. ALGORITHMIC STRATEGIES
- Develop systematic trading strategies
- Backtest and validate trading algorithms
- Monitor strategy performance metrics
- Optimize execution algorithms
Your analysis should be purely quantitative, based on statistical evidence and mathematical models rather than subjective opinions.""",
max_loops=1,
llm=vllm,
)
# Portfolio Strategy Agent
portfolio_strategist = Agent(
agent_name="Portfolio-Strategy-Agent",
agent_description="Expert in portfolio management and asset allocation",
system_prompt="""You are an expert Portfolio Strategy Agent specializing in portfolio construction and management. Your core responsibilities include:
1. ASSET ALLOCATION
- Develop strategic asset allocation frameworks
- Recommend tactical asset allocation shifts
- Optimize portfolio weightings
- Balance risk and return objectives
2. PORTFOLIO ANALYSIS
- Calculate portfolio risk metrics
- Monitor sector and factor exposures
- Analyze portfolio correlation matrix
- Track performance attribution
3. RISK MANAGEMENT
- Implement portfolio hedging strategies
- Monitor and adjust position sizing
- Set stop-loss and rebalancing rules
- Develop drawdown protection strategies
4. PORTFOLIO OPTIMIZATION
- Calculate efficient frontier analysis
- Optimize for various objectives:
* Maximum Sharpe Ratio
* Minimum Volatility
* Maximum Diversification
- Consider transaction costs and taxes
Your recommendations should focus on portfolio-level decisions that optimize risk-adjusted returns while meeting specific investment objectives.""",
max_loops=1,
llm=vllm,
)
# Create a list of all agents
stock_analysis_agents = [
technical_analyst,
fundamental_analyst,
sentiment_analyst,
quant_analyst,
portfolio_strategist
]
swarm = ConcurrentWorkflow(
name="Stock-Analysis-Swarm",
description="A swarm of agents that analyze stocks and provide a comprehensive analysis of the current trends and opportunities.",
agents=stock_analysis_agents,
)
swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets")
```
## Best Practices
!!! success "Optimization Tips"
1. **Agent Design**
- Keep system prompts focused and specific
- Use clear role definitions
- Include error handling guidelines
2. **Resource Management**
- Monitor memory usage with large models
- Implement proper cleanup procedures
- Use batching for multiple queries
3. **Output Handling**
- Implement proper logging
- Format outputs consistently
- Include error checking
## Common Issues and Solutions
!!! warning "Troubleshooting"
Common issues you might encounter:
1. **Memory Issues**
- *Problem*: VLLM consuming too much memory
- *Solution*: Adjust batch sizes and model parameters
2. **Agent Coordination**
- *Problem*: Agents providing conflicting information
- *Solution*: Implement consensus mechanisms or priority rules
3. **Performance**
- *Problem*: Slow response times
- *Solution*: Use proper batching and optimize model loading
## FAQ
??? question "Can I use different models for different agents?"
Yes, you can initialize multiple VLLM wrappers with different models for each agent. However, be mindful of memory usage.
??? question "How many agents can run concurrently?"
The number depends on your hardware resources. Start with 3-5 agents and scale based on performance.
??? question "Can I customize agent communication patterns?"
Yes, you can modify the ConcurrentWorkflow class or create custom workflows for specific communication patterns.
## Advanced Configuration
!!! example "Extended Settings"
```python
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
temperature=0.7,
max_tokens=2048,
top_p=0.95,
)
```
## Contributing
!!! info "Get Involved"
We welcome contributions! Here's how you can help:
1. Report bugs and issues
2. Submit feature requests
3. Contribute to documentation
4. Share example use cases
## Resources
!!! abstract "Additional Reading"
- [VLLM Documentation](https://docs.vllm.ai/en/latest/)

@ -0,0 +1,194 @@
# vLLM Integration Guide
!!! info "Overview"
vLLM is a high-performance and easy-to-use library for LLM inference and serving. This guide explains how to integrate vLLM with Swarms for efficient, production-grade language model deployment.
## Installation
!!! note "Prerequisites"
Before you begin, make sure you have Python 3.8+ installed on your system.
=== "pip"
```bash
pip install -U vllm swarms
```
=== "poetry"
```bash
poetry add vllm swarms
```
## Basic Usage
Here's a simple example of how to use vLLM with Swarms:
```python title="basic_usage.py"
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize the vLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
temperature=0.7,
max_tokens=4000
)
# Run inference
response = vllm.run("What is the capital of France?")
print(response)
```
## VLLMWrapper Class
!!! abstract "Class Overview"
The `VLLMWrapper` class provides a convenient interface for working with vLLM models.
### Key Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `model_name` | str | Name of the model to use | "meta-llama/Llama-2-7b-chat-hf" |
| `system_prompt` | str | System prompt to use | None |
| `stream` | bool | Whether to stream the output | False |
| `temperature` | float | Sampling temperature | 0.5 |
| `max_tokens` | int | Maximum number of tokens to generate | 4000 |
### Example with Custom Parameters
```python title="custom_parameters.py"
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-13b-chat-hf",
system_prompt="You are an expert in artificial intelligence.",
temperature=0.8,
max_tokens=2000
)
```
## Integration with Agents
You can easily integrate vLLM with Swarms agents for more complex workflows:
```python title="agent_integration.py"
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize vLLM
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant."
)
# Create an agent with vLLM
agent = Agent(
agent_name="Research-Agent",
agent_description="Expert in conducting research and analysis",
system_prompt="""You are an expert research agent. Your tasks include:
1. Analyzing complex topics
2. Providing detailed summaries
3. Making data-driven recommendations""",
llm=vllm,
max_loops=1
)
# Run the agent
response = agent.run("Research the impact of AI on healthcare")
```
## Advanced Features
### Batch Processing
!!! tip "Performance Optimization"
Use batch processing for efficient handling of multiple tasks simultaneously.
```python title="batch_processing.py"
tasks = [
"What is machine learning?",
"Explain neural networks",
"Describe deep learning"
]
results = vllm.batched_run(tasks, batch_size=3)
```
### Error Handling
!!! warning "Error Management"
Always implement proper error handling in production environments.
```python title="error_handling.py"
from loguru import logger
try:
response = vllm.run("Complex task")
except Exception as error:
logger.error(f"Error occurred: {error}")
```
## Best Practices
!!! success "Recommended Practices"
=== "Model Selection"
- Choose appropriate model sizes based on your requirements
- Consider the trade-off between model size and inference speed
=== "System Resources"
- Ensure sufficient GPU memory for your chosen model
- Monitor resource usage during batch processing
=== "Prompt Engineering"
- Use clear and specific system prompts
- Structure user prompts for optimal results
=== "Error Handling"
- Implement proper error handling and logging
- Set up monitoring for production deployments
=== "Performance"
- Use batch processing for multiple tasks
- Adjust max_tokens based on your use case
- Fine-tune temperature for optimal output quality
## Example: Multi-Agent System
Here's an example of creating a multi-agent system using vLLM:
```python title="multi_agent_system.py"
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize vLLM
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant."
)
# Create specialized agents
research_agent = Agent(
agent_name="Research-Agent",
agent_description="Expert in research",
system_prompt="You are a research expert.",
llm=vllm
)
analysis_agent = Agent(
agent_name="Analysis-Agent",
agent_description="Expert in analysis",
system_prompt="You are an analysis expert.",
llm=vllm
)
# Create a workflow
agents = [research_agent, analysis_agent]
workflow = ConcurrentWorkflow(
name="Research-Analysis-Workflow",
description="Comprehensive research and analysis workflow",
agents=agents
)
# Run the workflow
result = workflow.run("Analyze the impact of renewable energy")
```

@ -1,5 +1,6 @@
from swarms.utils.vllm_wrapper import VLLMWrapper from swarms.utils.vllm_wrapper import VLLMWrapper
def main(): def main():
# Initialize the vLLM wrapper with a model # Initialize the vLLM wrapper with a model
# Note: You'll need to have the model downloaded or specify a HuggingFace model ID # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
@ -31,14 +32,15 @@ def main():
tasks = [ tasks = [
"What is vLLM?", "What is vLLM?",
"How does vLLM improve inference speed?", "How does vLLM improve inference speed?",
"What are the main features of vLLM?" "What are the main features of vLLM?",
] ]
responses = llm.batched_run(tasks, batch_size=2) responses = llm.batched_run(tasks, batch_size=2)
print("\nBatched responses:") print("\nBatched responses:")
for task, response in zip(tasks, responses): for task, response in zip(tasks, responses):
print(f"\nTask: {task}") print(f"\nTask: {task}")
print(f"Response: {response}") print(f"Response: {response}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

@ -0,0 +1,68 @@
import os
from pathlib import Path
def concat_all_md_files(root_dir, output_filename="llm.txt"):
"""
Recursively searches for all .md files in directory and subdirectories,
then concatenates them into a single output file.
Args:
root_dir (str): Root directory to search for .md files
output_filename (str): Name of output file (default: llm.txt)
Returns:
str: Path to the created output file
"""
try:
root_dir = Path(root_dir).resolve()
if not root_dir.is_dir():
raise ValueError(f"Directory not found: {root_dir}")
# Collect all .md files recursively
md_files = []
for root, _, files in os.walk(root_dir):
for file in files:
if file.lower().endswith(".md"):
full_path = Path(root) / file
md_files.append(full_path)
if not md_files:
print(
f"No .md files found in {root_dir} or its subdirectories"
)
return None
# Create output file in root directory
output_path = root_dir / output_filename
with open(output_path, "w", encoding="utf-8") as outfile:
for md_file in sorted(md_files):
try:
# Get relative path for header
rel_path = md_file.relative_to(root_dir)
with open(
md_file, "r", encoding="utf-8"
) as infile:
content = infile.read()
outfile.write(f"# File: {rel_path}\n\n")
outfile.write(content)
outfile.write(
"\n\n" + "-" * 50 + "\n\n"
) # Separator
except Exception as e:
print(f"Error processing {rel_path}: {str(e)}")
continue
print(
f"Created {output_path} with {len(md_files)} files merged"
)
return str(output_path)
except Exception as e:
print(f"Fatal error: {str(e)}")
return None
if __name__ == "__main__":
concat_all_md_files("docs")

@ -0,0 +1,20 @@
# math_server.py
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Math")
@mcp.tool()
def add(a: int, b: int) -> int:
"""Add two numbers"""
return a + b
@mcp.tool()
def multiply(a: int, b: int) -> int:
"""Multiply two numbers"""
return a * b
if __name__ == "__main__":
mcp.run(transport="sse")

@ -22,4 +22,4 @@ pytest>=8.1.1
networkx networkx
aiofiles aiofiles
httpx httpx
vllm>=0.2.0 # vllm>=0.2.0

@ -58,6 +58,12 @@ from swarms.utils.litellm_tokenizer import count_tokens
from swarms.utils.pdf_to_text import pdf_to_text from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.str_to_dict import str_to_dict from swarms.utils.str_to_dict import str_to_dict
from swarms.tools.mcp_integration import (
batch_mcp_flow,
mcp_flow_get_tool_schema,
MCPServerSseParams,
)
# Utils # Utils
# Custom stopping condition # Custom stopping condition
@ -352,6 +358,7 @@ class Agent:
role: agent_roles = "worker", role: agent_roles = "worker",
no_print: bool = False, no_print: bool = False,
tools_list_dictionary: Optional[List[Dict[str, Any]]] = None, tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
mcp_servers: List[MCPServerSseParams] = [],
*args, *args,
**kwargs, **kwargs,
): ):
@ -471,6 +478,7 @@ class Agent:
self.role = role self.role = role
self.no_print = no_print self.no_print = no_print
self.tools_list_dictionary = tools_list_dictionary self.tools_list_dictionary = tools_list_dictionary
self.mcp_servers = mcp_servers
if ( if (
self.agent_name is not None self.agent_name is not None
@ -584,6 +592,12 @@ class Agent:
if self.llm is None: if self.llm is None:
self.llm = self.llm_handling() self.llm = self.llm_handling()
if (
self.tools_list_dictionary is None
and self.mcp_servers is not None
):
self.tools_list_dictionary = self.mcp_tool_handling()
def llm_handling(self): def llm_handling(self):
from swarms.utils.litellm_wrapper import LiteLLM from swarms.utils.litellm_wrapper import LiteLLM
@ -631,6 +645,69 @@ class Agent:
logger.error(f"Error in llm_handling: {e}") logger.error(f"Error in llm_handling: {e}")
return None return None
def mcp_execution_flow(self, response: any):
"""
Executes the MCP (Model Context Protocol) flow based on the provided response.
This method takes a response, converts it from a string to a dictionary format,
and checks for the presence of a tool name or a name in the response. If either
is found, it retrieves the tool name and proceeds to call the batch_mcp_flow
function to execute the corresponding tool actions.
Args:
response (any): The response to be processed, which can be in string format
that represents a dictionary.
Returns:
The output from the batch_mcp_flow function, which contains the results of
the tool execution. If an error occurs during processing, it logs the error
and returns None.
Raises:
Exception: Logs any exceptions that occur during the execution flow.
"""
try:
response = str_to_dict(response)
tool_output = batch_mcp_flow(
self.mcp_servers,
function_call=response,
)
return tool_output
except Exception as e:
logger.error(f"Error in mcp_execution_flow: {e}")
return None
def mcp_tool_handling(self):
"""
Handles the retrieval of tool schemas from the MCP servers.
This method iterates over the list of MCP servers, retrieves the tool schema
for each server using the mcp_flow_get_tool_schema function, and compiles
these schemas into a list. The resulting list is stored in the
tools_list_dictionary attribute.
Returns:
list: A list of tool schemas retrieved from the MCP servers. If an error
occurs during the retrieval process, it logs the error and returns None.
Raises:
Exception: Logs any exceptions that occur during the tool handling process.
"""
try:
self.tools_list_dictionary = []
for mcp_server in self.mcp_servers:
tool_schema = mcp_flow_get_tool_schema(mcp_server)
self.tools_list_dictionary.append(tool_schema)
print(self.tools_list_dictionary)
return self.tools_list_dictionary
except Exception as e:
logger.error(f"Error in mcp_tool_handling: {e}")
return None
def setup_config(self): def setup_config(self):
# The max_loops will be set dynamically if the dynamic_loop # The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops is True: if self.dynamic_loops is True:

@ -1,554 +1,392 @@
from contextlib import AsyncExitStack from __future__ import annotations
from types import TracebackType
from typing import (
Any,
Callable,
Coroutine,
List,
Literal,
Optional,
TypedDict,
cast,
)
from mcp import ClientSession, StdioServerParameters from typing import Any, List
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ( from loguru import logger
CallToolResult,
EmbeddedResource, import abc
ImageContent, import asyncio
PromptMessage, from contextlib import AbstractAsyncContextManager, AsyncExitStack
TextContent, from pathlib import Path
from typing import Literal
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
) )
from mcp.types import ( from mcp import (
ClientSession,
StdioServerParameters,
Tool as MCPTool, Tool as MCPTool,
stdio_client,
) )
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict
from swarms.utils.any_to_str import any_to_str
def convert_mcp_prompt_message_to_message(
message: PromptMessage,
) -> str:
"""Convert an MCP prompt message to a string message.
Args: class MCPServer(abc.ABC):
message: MCP prompt message to convert """Base class for Model Context Protocol servers."""
Returns: @abc.abstractmethod
a string message async def connect(self):
""" """Connect to the server. For example, this might mean spawning a subprocess or
if message.content.type == "text": opening a network connection. The server is expected to remain connected until
if message.role == "user": `cleanup()` is called.
return str(message.content.text) """
elif message.role == "assistant": pass
return str(
message.content.text @property
) # Fixed attribute name from str to text @abc.abstractmethod
else: def name(self) -> str:
raise ValueError( """A readable name for the server."""
f"Unsupported prompt message role: {message.role}" pass
)
@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass
raise ValueError( @abc.abstractmethod
f"Unsupported prompt message content type: {message.content.type}" async def list_tools(self) -> list[MCPTool]:
) """List the tools available on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
pass
async def load_mcp_prompt(
session: ClientSession,
name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Load MCP prompt and convert to messages."""
response = await session.get_prompt(name, arguments)
return [ class _MCPServerWithClientSession(MCPServer, abc.ABC):
convert_mcp_prompt_message_to_message(message) """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
for message in response.messages
]
def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
DEFAULT_ENCODING = "utf-8" # The cache is always dirty at startup, so that we fetch tools at least once
DEFAULT_ENCODING_ERROR_HANDLER = "strict" self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
DEFAULT_HTTP_TIMEOUT = 5 @abc.abstractmethod
DEFAULT_SSE_READ_TIMEOUT = 60 * 5 def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
class StdioConnection(TypedDict): # Return from cache if caching is enabled, we have tools, and the cache is not dirty
transport: Literal["stdio"] if (
self.cache_tools_list
and not self._cache_dirty
and self._tools_list
):
return self._tools_list
command: str # Reset the cache dirty to False
"""The executable to run to start the server.""" self._cache_dirty = False
args: list[str] # Fetch the tools from the server
"""Command line arguments to pass to the executable.""" self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
env: dict[str, str] | None async def call_tool(
"""The environment to use when spawning the process.""" self, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get(
"name"
)
encoding: str if not tool_name:
"""The text encoding used when sending/receiving messages to the server.""" raise Exception("No tool name found in arguments")
encoding_error_handler: Literal["strict", "ignore", "replace"] if not self.session:
""" raise Exception(
The text encoding error handler. "Server not initialized. Make sure you call `connect()` first."
)
See https://docs.python.org/3/library/codecs.html#codec-base-classes for return await self.session.call_tool(tool_name, arguments)
explanations of possible values
"""
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
class SSEConnection(TypedDict):
transport: Literal["sse"]
url: str class MCPServerStdioParams(TypedDict):
"""The URL of the SSE endpoint to connect to.""" """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""
headers: dict[str, Any] | None command: str
"""HTTP headers to send to the SSE endpoint""" """The executable to run to start the server. For example, `python` or `node`."""
timeout: float args: NotRequired[list[str]]
"""HTTP timeout""" """Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""
sse_read_timeout: float env: NotRequired[dict[str, str]]
"""SSE read timeout""" """The environment variables to set for the server. ."""
cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""
NonTextContent = ImageContent | EmbeddedResource encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
encoding_error_handler: NotRequired[
Literal["strict", "ignore", "replace"]
]
"""The text encoding error handler. Defaults to `strict`.
def _convert_call_tool_result( See https://docs.python.org/3/library/codecs.html#codec-base-classes for
call_tool_result: CallToolResult, explanations of possible values.
) -> tuple[str | list[str], list[NonTextContent] | None]: """
text_contents: list[TextContent] = []
non_text_contents = []
for content in call_tool_result.content:
if isinstance(content, TextContent):
text_contents.append(content)
else:
non_text_contents.append(content)
tool_content: str | list[str] = [
content.text for content in text_contents
]
if len(text_contents) == 1:
tool_content = tool_content[0]
if call_tool_result.isError: class MCPServerStdio(_MCPServerWithClientSession):
raise ValueError("Error calling tool") """MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""
return tool_content, non_text_contents or None def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the stdio transport.
Args:
params: The params that configure the server. This includes the command to run to
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get(
"encoding_error_handler", "strict"
),
)
def convert_mcp_tool_to_function( self._name = name or f"stdio: {self.params.command}"
session: ClientSession,
tool: MCPTool,
) -> Callable[
...,
Coroutine[
Any, Any, tuple[str | list[str], list[NonTextContent] | None]
],
]:
"""Convert an MCP tool to a callable function.
NOTE: this tool can be executed only in a context of an active MCP client session. def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)
Args: @property
session: MCP client session def name(self) -> str:
tool: MCP tool to convert """A readable name for the server."""
return self._name
Returns:
a callable function
"""
async def call_tool( class MCPServerSseParams(TypedDict):
**arguments: dict[str, Any], """Mirrors the params in`mcp.client.sse.sse_client`."""
) -> tuple[str | list[str], list[NonTextContent] | None]:
"""Execute the tool with the given arguments."""
call_tool_result = await session.call_tool(
tool.name, arguments
)
return _convert_call_tool_result(call_tool_result)
# Add metadata as attributes to the function url: str
call_tool.__name__ = tool.name """The URL of the server."""
call_tool.__doc__ = tool.description or ""
call_tool.schema = tool.inputSchema
return call_tool headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""
timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""
async def load_mcp_tools(session: ClientSession) -> list[Callable]: sse_read_timeout: NotRequired[float]
"""Load all available MCP tools and convert them to callable functions.""" """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
tools = await session.list_tools()
return [
convert_mcp_tool_to_function(session, tool)
for tool in tools.tools
]
class MultiServerMCPClient: class MCPServerSse(_MCPServerWithClientSession):
"""Client for connecting to multiple MCP servers and loading tools from them.""" """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""
def __init__( def __init__(
self, self,
connections: dict[ params: MCPServerSseParams,
str, StdioConnection | SSEConnection cache_tools_list: bool = False,
] = None, name: str | None = None,
) -> None: ):
"""Initialize a MultiServerMCPClient with MCP servers connections. """Create a new MCP server based on the HTTP with SSE transport.
Args: Args:
connections: A dictionary mapping server names to connection configurations. params: The params that configure the server. This includes the URL of the server,
Each configuration can be either a StdioConnection or SSEConnection. the headers to send to the server, the timeout for the HTTP request, and the
If None, no initial connections are established. timeout for the SSE connection.
Example: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
```python fetched from the server on each call to `list_tools()`. The cache can be
async with MultiServerMCPClient( invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
{ if you know the server will not change its tools list, because it can drastically
"math": { improve latency (by avoiding a round-trip to the server every time).
"command": "python",
# Make sure to update to the full absolute path to your math_server.py file name: A readable name for the server. If not provided, we'll create one from the
"args": ["/path/to/math_server.py"], URL.
"transport": "stdio",
},
"weather": {
# make sure you start your weather server on port 8000
"url": "http://localhost:8000/sse",
"transport": "sse",
}
}
) as client:
all_tools = client.get_tools()
...
```
""" """
self.connections = connections super().__init__(cache_tools_list)
self.exit_stack = AsyncExitStack()
self.sessions: dict[str, ClientSession] = {}
self.server_name_to_tools: dict[str, list[Callable]] = {}
async def _initialize_session_and_load_tools( self.params = params
self, server_name: str, session: ClientSession self._name = name or f"sse: {self.params['url']}"
) -> None:
"""Initialize a session and load tools from it.
Args: def create_streams(
server_name: Name to identify this server connection self,
session: The ClientSession to initialize ) -> AbstractAsyncContextManager[
""" tuple[
# Initialize the session MemoryObjectReceiveStream[JSONRPCMessage | Exception],
await session.initialize() MemoryObjectSendStream[JSONRPCMessage],
self.sessions[server_name] = session ]
]:
"""Create the streams for the server."""
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get(
"sse_read_timeout", 60 * 5
),
)
# Load tools from this server @property
server_tools = await load_mcp_tools(session) def name(self) -> str:
self.server_name_to_tools[server_name] = server_tools """A readable name for the server."""
return self._name
async def connect_to_server(
self,
server_name: str,
*,
transport: Literal["stdio", "sse"] = "stdio",
**kwargs,
) -> None:
"""Connect to an MCP server using either stdio or SSE.
This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse def mcp_flow_get_tool_schema(
based on the provided transport parameter. params: MCPServerSseParams,
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
Args: # Connect the server
server_name: Name to identify this server connection asyncio.run(server.connect())
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
**kwargs: Additional arguments to pass to the specific connection method
Raises: # Return the server
ValueError: If transport is not recognized output = asyncio.run(server.list_tools())
ValueError: If required parameters for the specified transport are missing
"""
if transport == "sse":
if "url" not in kwargs:
raise ValueError(
"'url' parameter is required for SSE connection"
)
await self.connect_to_server_via_sse(
server_name,
url=kwargs["url"],
headers=kwargs.get("headers"),
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
sse_read_timeout=kwargs.get(
"sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
),
)
elif transport == "stdio":
if "command" not in kwargs:
raise ValueError(
"'command' parameter is required for stdio connection"
)
if "args" not in kwargs:
raise ValueError(
"'args' parameter is required for stdio connection"
)
await self.connect_to_server_via_stdio(
server_name,
command=kwargs["command"],
args=kwargs["args"],
env=kwargs.get("env"),
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
encoding_error_handler=kwargs.get(
"encoding_error_handler",
DEFAULT_ENCODING_ERROR_HANDLER,
),
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
async def connect_to_server_via_stdio( # Cleanup the server
self, asyncio.run(server.cleanup())
server_name: str,
*,
command: str,
args: list[str],
env: dict[str, str] | None = None,
encoding: str = DEFAULT_ENCODING,
encoding_error_handler: Literal[
"strict", "ignore", "replace"
] = DEFAULT_ENCODING_ERROR_HANDLER,
) -> None:
"""Connect to a specific MCP server using stdio
Args: return output.model_dump()
server_name: Name to identify this server connection
command: Command to execute
args: Arguments for the command
env: Environment variables for the command
encoding: Character encoding
encoding_error_handler: How to handle encoding errors
"""
server_params = StdioServerParameters(
command=command,
args=args,
env=env,
encoding=encoding,
encoding_error_handler=encoding_error_handler,
)
# Create and store the connection
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
read, write = stdio_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
await self._initialize_session_and_load_tools( def mcp_flow(
server_name, session params: MCPServerSseParams,
) function_call: dict[str, Any],
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
async def connect_to_server_via_sse( # Connect the server
self, asyncio.run(server.connect())
server_name: str,
*,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = DEFAULT_HTTP_TIMEOUT,
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
) -> None:
"""Connect to a specific MCP server using SSE
Args: # Return the server
server_name: Name to identify this server connection output = asyncio.run(server.call_tool(function_call))
url: URL of the SSE server
headers: HTTP headers to send to the SSE endpoint
timeout: HTTP timeout
sse_read_timeout: SSE read timeout
"""
# Create and store the connection
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout, sse_read_timeout)
)
read, write = sse_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
await self._initialize_session_and_load_tools( output = output.model_dump()
server_name, session
)
def get_tools(self) -> list[Callable]: # Cleanup the server
"""Get a list of all tools from all connected servers.""" asyncio.run(server.cleanup())
all_tools: list[Callable] = []
for server_tools in self.server_name_to_tools.values():
all_tools.extend(server_tools)
return all_tools
async def get_prompt( return any_to_str(output)
self,
server_name: str,
prompt_name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Get a prompt from a given MCP server."""
session = self.sessions[server_name]
return await load_mcp_prompt(session, prompt_name, arguments)
async def __aenter__(self) -> "MultiServerMCPClient":
try:
connections = self.connections or {}
for server_name, connection in connections.items():
connection_dict = connection.copy()
transport = connection_dict.pop("transport")
if transport == "stdio":
await self.connect_to_server_via_stdio(
server_name, **connection_dict
)
elif transport == "sse":
await self.connect_to_server_via_sse(
server_name, **connection_dict
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
return self
except Exception:
await self.exit_stack.aclose()
raise
async def __aexit__(
self, def batch_mcp_flow(
exc_type: type[BaseException] | None, params: List[MCPServerSseParams],
exc_val: BaseException | None, function_call: List[dict[str, Any]] = [],
exc_tb: TracebackType | None, ) -> MCPServer:
) -> None: output_list = []
await self.exit_stack.aclose()
for param in params:
output = mcp_flow(param, function_call)
# #!/usr/bin/env python3 output_list.append(output)
# import asyncio
# import os return output_list
# import json
# from typing import List, Any, Callable
# # # Import our MCP client module
# # from mcp_client import MultiServerMCPClient
# async def main():
# """Test script for demonstrating MCP client usage."""
# print("Starting MCP Client test...")
# # Create a connection to multiple MCP servers
# # You'll need to update these paths to match your setup
# async with MultiServerMCPClient(
# {
# "math": {
# "transport": "stdio",
# "command": "python",
# "args": ["/path/to/math_server.py"],
# "env": {"DEBUG": "1"},
# },
# "search": {
# "transport": "sse",
# "url": "http://localhost:8000/sse",
# "headers": {
# "Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
# },
# },
# }
# ) as client:
# # Get all available tools
# tools = client.get_tools()
# print(f"Found {len(tools)} tools across all servers")
# # Print tool information
# for i, tool in enumerate(tools):
# print(f"\nTool {i+1}: {tool.__name__}")
# print(f" Description: {tool.__doc__}")
# if hasattr(tool, "schema") and tool.schema:
# print(
# f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
# )
# # Example: Use a specific tool if available
# calculator_tool = next(
# (t for t in tools if t.__name__ == "calculator"), None
# )
# if calculator_tool:
# print("\n\nTesting calculator tool:")
# try:
# # Call the tool as an async function
# result, artifacts = await calculator_tool(
# expression="2 + 2 * 3"
# )
# print(f" Calculator result: {result}")
# if artifacts:
# print(
# f" With {len(artifacts)} additional artifacts"
# )
# except Exception as e:
# print(f" Error using calculator: {e}")
# # Example: Load a prompt from a server
# try:
# print("\n\nTesting prompt loading:")
# prompt_messages = await client.get_prompt(
# "math",
# "calculation_introduction",
# {"user_name": "Test User"},
# )
# print(
# f" Loaded prompt with {len(prompt_messages)} messages:"
# )
# for i, msg in enumerate(prompt_messages):
# print(f" Message {i+1}: {msg[:50]}...")
# except Exception as e:
# print(f" Error loading prompt: {e}")
# async def create_custom_tool():
# """Example of creating a custom tool function."""
# # Define a tool function with metadata
# async def add_numbers(a: float, b: float) -> tuple[str, None]:
# """Add two numbers together."""
# result = a + b
# return f"The sum of {a} and {b} is {result}", None
# # Add metadata to the function
# add_numbers.__name__ = "add_numbers"
# add_numbers.__doc__ = (
# "Add two numbers together and return the result."
# )
# add_numbers.schema = {
# "type": "object",
# "properties": {
# "a": {"type": "number", "description": "First number"},
# "b": {"type": "number", "description": "Second number"},
# },
# "required": ["a", "b"],
# }
# # Use the tool
# result, _ = await add_numbers(a=5, b=7)
# print(f"\nCustom tool result: {result}")
# if __name__ == "__main__":
# # Run both examples
# loop = asyncio.get_event_loop()
# loop.run_until_complete(main())
# loop.run_until_complete(create_custom_tool())

@ -6,11 +6,15 @@ try:
except ImportError: except ImportError:
import subprocess import subprocess
import sys import sys
print("Installing vllm") print("Installing vllm")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"]) subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-U", "vllm"]
)
print("vllm installed") print("vllm installed")
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
class VLLMWrapper: class VLLMWrapper:
""" """
A wrapper class for vLLM that provides a similar interface to LiteLLM. A wrapper class for vLLM that provides a similar interface to LiteLLM.
@ -54,7 +58,7 @@ class VLLMWrapper:
self.tools_list_dictionary = tools_list_dictionary self.tools_list_dictionary = tools_list_dictionary
self.tool_choice = tool_choice self.tool_choice = tool_choice
self.parallel_tool_calls = parallel_tool_calls self.parallel_tool_calls = parallel_tool_calls
# Initialize vLLM # Initialize vLLM
self.llm = LLM(model=model_name, **kwargs) self.llm = LLM(model=model_name, **kwargs)
self.sampling_params = SamplingParams( self.sampling_params = SamplingParams(
@ -90,12 +94,12 @@ class VLLMWrapper:
""" """
try: try:
prompt = self._prepare_prompt(task) prompt = self._prepare_prompt(task)
outputs = self.llm.generate(prompt, self.sampling_params) outputs = self.llm.generate(prompt, self.sampling_params)
response = outputs[0].outputs[0].text.strip() response = outputs[0].outputs[0].text.strip()
return response return response
except Exception as error: except Exception as error:
logger.error(f"Error in VLLMWrapper: {error}") logger.error(f"Error in VLLMWrapper: {error}")
raise error raise error
@ -114,7 +118,9 @@ class VLLMWrapper:
""" """
return self.run(task, *args, **kwargs) return self.run(task, *args, **kwargs)
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]: def batched_run(
self, tasks: List[str], batch_size: int = 10
) -> List[str]:
""" """
Run the model for multiple tasks in batches. Run the model for multiple tasks in batches.
@ -125,14 +131,16 @@ class VLLMWrapper:
Returns: Returns:
List[str]: List of model responses. List[str]: List of model responses.
""" """
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}") logger.info(
f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
)
results = [] results = []
for i in range(0, len(tasks), batch_size): for i in range(0, len(tasks), batch_size):
batch = tasks[i:i + batch_size] batch = tasks[i : i + batch_size]
for task in batch: for task in batch:
logger.info(f"Running task: {task}") logger.info(f"Running task: {task}")
results.append(self.run(task)) results.append(self.run(task))
logger.info("Completed all tasks.") logger.info("Completed all tasks.")
return results return results

@ -0,0 +1,212 @@
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
from dotenv import load_dotenv
load_dotenv()
# Initialize the VLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
)
# Technical Analysis Agent
technical_analyst = Agent(
agent_name="Technical-Analysis-Agent",
agent_description="Expert in technical analysis and chart patterns",
system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
max_loops=1,
llm=vllm,
)
# Fundamental Analysis Agent
fundamental_analyst = Agent(
agent_name="Fundamental-Analysis-Agent",
agent_description="Expert in company fundamentals and valuation",
system_prompt="""You are an expert Fundamental Analysis Agent specializing in company valuation and financial metrics. Your core responsibilities include:
1. FINANCIAL STATEMENT ANALYSIS
- Analyze income statements, balance sheets, and cash flow statements
- Calculate and interpret key financial ratios
- Evaluate revenue growth and profit margins
- Assess company's debt levels and cash position
2. VALUATION METRICS
- Calculate fair value using multiple valuation methods:
* Discounted Cash Flow (DCF)
* Price-to-Earnings (P/E)
* Price-to-Book (P/B)
* Enterprise Value/EBITDA
- Compare valuations against industry peers
3. BUSINESS MODEL ASSESSMENT
- Evaluate competitive advantages and market position
- Analyze industry dynamics and market share
- Assess management quality and corporate governance
- Identify potential risks and growth opportunities
4. ECONOMIC CONTEXT
- Consider macroeconomic factors affecting the company
- Analyze industry cycles and trends
- Evaluate regulatory environment and compliance
- Assess global market conditions
Your analysis should be comprehensive, focusing on both quantitative metrics and qualitative factors that impact long-term value.""",
max_loops=1,
llm=vllm,
)
# Market Sentiment Agent
sentiment_analyst = Agent(
agent_name="Market-Sentiment-Agent",
agent_description="Expert in market psychology and sentiment analysis",
system_prompt="""You are an expert Market Sentiment Agent specializing in analyzing market psychology and investor behavior. Your key responsibilities include:
1. SENTIMENT INDICATORS
- Monitor and interpret market sentiment indicators:
* VIX (Fear Index)
* Put/Call Ratio
* Market Breadth
* Investor Surveys
- Track institutional vs retail investor behavior
2. NEWS AND SOCIAL MEDIA ANALYSIS
- Analyze news flow and media sentiment
- Monitor social media trends and discussions
- Track analyst recommendations and changes
- Evaluate corporate insider trading patterns
3. MARKET POSITIONING
- Assess hedge fund positioning and exposure
- Monitor short interest and short squeeze potential
- Track fund flows and asset allocation trends
- Analyze options market sentiment
4. CONTRARIAN SIGNALS
- Identify extreme sentiment readings
- Detect potential market turning points
- Analyze historical sentiment patterns
- Provide contrarian trading opportunities
Your analysis should combine quantitative sentiment metrics with qualitative assessment of market psychology and crowd behavior.""",
max_loops=1,
llm=vllm,
)
# Quantitative Strategy Agent
quant_analyst = Agent(
agent_name="Quantitative-Strategy-Agent",
agent_description="Expert in quantitative analysis and algorithmic strategies",
system_prompt="""You are an expert Quantitative Strategy Agent specializing in data-driven investment strategies. Your primary responsibilities include:
1. FACTOR ANALYSIS
- Analyze and monitor factor performance:
* Value
* Momentum
* Quality
* Size
* Low Volatility
- Calculate factor exposures and correlations
2. STATISTICAL ANALYSIS
- Perform statistical arbitrage analysis
- Calculate and monitor pair trading opportunities
- Analyze market anomalies and inefficiencies
- Develop mean reversion strategies
3. RISK MODELING
- Build and maintain risk models
- Calculate portfolio optimization metrics
- Monitor correlation matrices
- Analyze tail risk and stress scenarios
4. ALGORITHMIC STRATEGIES
- Develop systematic trading strategies
- Backtest and validate trading algorithms
- Monitor strategy performance metrics
- Optimize execution algorithms
Your analysis should be purely quantitative, based on statistical evidence and mathematical models rather than subjective opinions.""",
max_loops=1,
llm=vllm,
)
# Portfolio Strategy Agent
portfolio_strategist = Agent(
agent_name="Portfolio-Strategy-Agent",
agent_description="Expert in portfolio management and asset allocation",
system_prompt="""You are an expert Portfolio Strategy Agent specializing in portfolio construction and management. Your core responsibilities include:
1. ASSET ALLOCATION
- Develop strategic asset allocation frameworks
- Recommend tactical asset allocation shifts
- Optimize portfolio weightings
- Balance risk and return objectives
2. PORTFOLIO ANALYSIS
- Calculate portfolio risk metrics
- Monitor sector and factor exposures
- Analyze portfolio correlation matrix
- Track performance attribution
3. RISK MANAGEMENT
- Implement portfolio hedging strategies
- Monitor and adjust position sizing
- Set stop-loss and rebalancing rules
- Develop drawdown protection strategies
4. PORTFOLIO OPTIMIZATION
- Calculate efficient frontier analysis
- Optimize for various objectives:
* Maximum Sharpe Ratio
* Minimum Volatility
* Maximum Diversification
- Consider transaction costs and taxes
Your recommendations should focus on portfolio-level decisions that optimize risk-adjusted returns while meeting specific investment objectives.""",
max_loops=1,
llm=vllm,
)
# Create a list of all agents
stock_analysis_agents = [
technical_analyst,
fundamental_analyst,
sentiment_analyst,
quant_analyst,
portfolio_strategist
]
swarm = ConcurrentWorkflow(
name="Stock-Analysis-Swarm",
description="A swarm of agents that analyze stocks and provide a comprehensive analysis of the current trends and opportunities.",
agents=stock_analysis_agents,
)
swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets")
Loading…
Cancel
Save