vllm docs and cleanup

pull/770/merge
Kye Gomez 4 days ago
parent 50d646d069
commit a8b51f3150

@ -0,0 +1,25 @@
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.tools.mcp_integration import MCPServerSseParams

server_one = MCPServerSseParams(
    url="http://127.0.0.1:6274",
    headers={"Content-Type": "application/json"},
)

# Initialize the agent
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    agent_description="Personal finance advisor agent",
    system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
    max_loops=1,
    mcp_servers=[server_one],
    output_type="final",
)

out = agent.run("Use the add tool to add 2 and 2")
print(type(out))

@ -7,6 +7,10 @@ black . && echo "✅ Code formatting complete!" || echo "❌ Black formatting fa
echo "🔍 Running Ruff linter..."
ruff check . --fix && echo "✅ Linting complete!" || echo "❌ Linting failed!"
echo "Creating llm.txt file..."
python3 llm_txt.py && echo "✅ llm.txt file created!" || echo "❌ llm.txt file creation failed!"
echo "🏗️ Building package..."
poetry build && echo "✅ Build successful!" || echo "❌ Build failed!"

File diff suppressed because it is too large.

@ -283,6 +283,7 @@ nav:
- Ollama: "swarms/examples/ollama.md"
- OpenRouter: "swarms/examples/openrouter.md"
- XAI: "swarms/examples/xai.md"
- VLLM: "swarms/examples/vllm_integration.md"
- Swarms Tools:
- Agent with Yahoo Finance: "swarms/examples/yahoo_finance.md"
- Twitter Agents: "swarms_tools/twitter.md"
@ -299,6 +300,7 @@ nav:
- Group Chat Example: "swarms/examples/groupchat_example.md"
- Meme Agent Builder: "swarms/examples/meme_agents.md"
- Sequential Workflow Example: "swarms/examples/sequential_example.md"
- ConcurrentWorkflow with VLLM Agents: "swarms/examples/vllm.md"
- External Agents:
- Swarms of Browser Agents: "swarms/examples/swarms_of_browser_agents.md"
- Swarms UI:

@ -0,0 +1,429 @@
# VLLM Swarm Agents
!!! tip "Quick Summary"
This guide demonstrates how to create a sophisticated multi-agent system using VLLM and Swarms for comprehensive stock market analysis. You'll learn how to configure and orchestrate multiple AI agents working together to provide deep market insights.
## Overview
The example showcases how to build a stock analysis system with 5 specialized agents:
- Technical Analysis Agent
- Fundamental Analysis Agent
- Market Sentiment Agent
- Quantitative Strategy Agent
- Portfolio Strategy Agent
Each agent has specific expertise and works collaboratively through a concurrent workflow.
## Prerequisites
!!! warning "Requirements"
Before starting, ensure you have:
- Python 3.8 or higher
- The Swarms package installed
- Access to VLLM-compatible models
- Sufficient compute resources for running VLLM
## Installation
!!! example "Setup Steps"
1. Install the Swarms package:
```bash
pip install swarms
```
2. Install VLLM dependencies (if not already installed):
```bash
pip install vllm
```
## Basic Usage
Here's a complete example of setting up the stock analysis swarm:
```python
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper

# Initialize the VLLM wrapper
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)
```
!!! note "Model Selection"
The example uses Llama-2-7b-chat, but you can use any VLLM-compatible model. Make sure you have the necessary permissions and resources to run your chosen model.
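A minimal sketch of swapping in a different checkpoint (the model ID below is a placeholder; substitute any vLLM-compatible model you have access to):
```python
from swarms.utils.vllm_wrapper import VLLMWrapper

# Placeholder model ID -- swap in any vLLM-compatible checkpoint you can run locally
vllm = VLLMWrapper(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    system_prompt="You are a helpful assistant.",
)
```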
## Agent Configuration
### Technical Analysis Agent
```python
technical_analyst = Agent(
    agent_name="Technical-Analysis-Agent",
    agent_description="Expert in technical analysis and chart patterns",
    system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
    max_loops=1,
    llm=vllm,
)
```
!!! tip "Agent Customization"
Each agent can be customized with different settings (see the sketch after this list):
- System prompts
- Temperature settings
- Max token limits
- Response formats
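As a rough sketch (the values shown are illustrative, not recommendations), a second wrapper with its own sampling settings can back a differently tuned agent:
```python
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLMWrapper

# A separate wrapper with its own sampling settings (values are illustrative)
creative_vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a creative market commentator.",
    temperature=0.9,
    max_tokens=1024,
)

commentary_agent = Agent(
    agent_name="Market-Commentary-Agent",
    agent_description="Provides narrative commentary on market conditions",
    system_prompt="Summarize market conditions in plain language.",
    max_loops=1,
    llm=creative_vllm,
)
```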
## Running the Swarm
To execute the swarm analysis:
```python
swarm = ConcurrentWorkflow(
    name="Stock-Analysis-Swarm",
    description="A swarm of agents that analyze stocks and provide comprehensive analysis.",
    agents=stock_analysis_agents,
)

# Run the analysis
response = swarm.run(
    "Analyze the best ETFs for gold and other similar commodities in volatile markets"
)
```
## Full Code Example
```python
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
# Initialize the VLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
)
# Technical Analysis Agent
technical_analyst = Agent(
agent_name="Technical-Analysis-Agent",
agent_description="Expert in technical analysis and chart patterns",
system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
max_loops=1,
llm=vllm,
)
# Fundamental Analysis Agent
fundamental_analyst = Agent(
agent_name="Fundamental-Analysis-Agent",
agent_description="Expert in company fundamentals and valuation",
system_prompt="""You are an expert Fundamental Analysis Agent specializing in company valuation and financial metrics. Your core responsibilities include:
1. FINANCIAL STATEMENT ANALYSIS
- Analyze income statements, balance sheets, and cash flow statements
- Calculate and interpret key financial ratios
- Evaluate revenue growth and profit margins
- Assess company's debt levels and cash position
2. VALUATION METRICS
- Calculate fair value using multiple valuation methods:
* Discounted Cash Flow (DCF)
* Price-to-Earnings (P/E)
* Price-to-Book (P/B)
* Enterprise Value/EBITDA
- Compare valuations against industry peers
3. BUSINESS MODEL ASSESSMENT
- Evaluate competitive advantages and market position
- Analyze industry dynamics and market share
- Assess management quality and corporate governance
- Identify potential risks and growth opportunities
4. ECONOMIC CONTEXT
- Consider macroeconomic factors affecting the company
- Analyze industry cycles and trends
- Evaluate regulatory environment and compliance
- Assess global market conditions
Your analysis should be comprehensive, focusing on both quantitative metrics and qualitative factors that impact long-term value.""",
max_loops=1,
llm=vllm,
)
# Market Sentiment Agent
sentiment_analyst = Agent(
agent_name="Market-Sentiment-Agent",
agent_description="Expert in market psychology and sentiment analysis",
system_prompt="""You are an expert Market Sentiment Agent specializing in analyzing market psychology and investor behavior. Your key responsibilities include:
1. SENTIMENT INDICATORS
- Monitor and interpret market sentiment indicators:
* VIX (Fear Index)
* Put/Call Ratio
* Market Breadth
* Investor Surveys
- Track institutional vs retail investor behavior
2. NEWS AND SOCIAL MEDIA ANALYSIS
- Analyze news flow and media sentiment
- Monitor social media trends and discussions
- Track analyst recommendations and changes
- Evaluate corporate insider trading patterns
3. MARKET POSITIONING
- Assess hedge fund positioning and exposure
- Monitor short interest and short squeeze potential
- Track fund flows and asset allocation trends
- Analyze options market sentiment
4. CONTRARIAN SIGNALS
- Identify extreme sentiment readings
- Detect potential market turning points
- Analyze historical sentiment patterns
- Provide contrarian trading opportunities
Your analysis should combine quantitative sentiment metrics with qualitative assessment of market psychology and crowd behavior.""",
max_loops=1,
llm=vllm,
)
# Quantitative Strategy Agent
quant_analyst = Agent(
agent_name="Quantitative-Strategy-Agent",
agent_description="Expert in quantitative analysis and algorithmic strategies",
system_prompt="""You are an expert Quantitative Strategy Agent specializing in data-driven investment strategies. Your primary responsibilities include:
1. FACTOR ANALYSIS
- Analyze and monitor factor performance:
* Value
* Momentum
* Quality
* Size
* Low Volatility
- Calculate factor exposures and correlations
2. STATISTICAL ANALYSIS
- Perform statistical arbitrage analysis
- Calculate and monitor pair trading opportunities
- Analyze market anomalies and inefficiencies
- Develop mean reversion strategies
3. RISK MODELING
- Build and maintain risk models
- Calculate portfolio optimization metrics
- Monitor correlation matrices
- Analyze tail risk and stress scenarios
4. ALGORITHMIC STRATEGIES
- Develop systematic trading strategies
- Backtest and validate trading algorithms
- Monitor strategy performance metrics
- Optimize execution algorithms
Your analysis should be purely quantitative, based on statistical evidence and mathematical models rather than subjective opinions.""",
max_loops=1,
llm=vllm,
)
# Portfolio Strategy Agent
portfolio_strategist = Agent(
agent_name="Portfolio-Strategy-Agent",
agent_description="Expert in portfolio management and asset allocation",
system_prompt="""You are an expert Portfolio Strategy Agent specializing in portfolio construction and management. Your core responsibilities include:
1. ASSET ALLOCATION
- Develop strategic asset allocation frameworks
- Recommend tactical asset allocation shifts
- Optimize portfolio weightings
- Balance risk and return objectives
2. PORTFOLIO ANALYSIS
- Calculate portfolio risk metrics
- Monitor sector and factor exposures
- Analyze portfolio correlation matrix
- Track performance attribution
3. RISK MANAGEMENT
- Implement portfolio hedging strategies
- Monitor and adjust position sizing
- Set stop-loss and rebalancing rules
- Develop drawdown protection strategies
4. PORTFOLIO OPTIMIZATION
- Calculate efficient frontier analysis
- Optimize for various objectives:
* Maximum Sharpe Ratio
* Minimum Volatility
* Maximum Diversification
- Consider transaction costs and taxes
Your recommendations should focus on portfolio-level decisions that optimize risk-adjusted returns while meeting specific investment objectives.""",
max_loops=1,
llm=vllm,
)
# Create a list of all agents
stock_analysis_agents = [
technical_analyst,
fundamental_analyst,
sentiment_analyst,
quant_analyst,
portfolio_strategist
]
swarm = ConcurrentWorkflow(
name="Stock-Analysis-Swarm",
description="A swarm of agents that analyze stocks and provide a comprehensive analysis of the current trends and opportunities.",
agents=stock_analysis_agents,
)
swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets")
```
## Best Practices
!!! success "Optimization Tips"
1. **Agent Design**
- Keep system prompts focused and specific
- Use clear role definitions
- Include error handling guidelines
2. **Resource Management**
- Monitor memory usage with large models
- Implement proper cleanup procedures
- Use batching for multiple queries
3. **Output Handling**
- Implement proper logging
- Format outputs consistently
- Include error checking (see the sketch below)
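A minimal sketch of those output-handling practices wrapped around a swarm run, assuming the `loguru` logger used elsewhere in this codebase:
```python
from loguru import logger

try:
    response = swarm.run(
        "Analyze the best ETFs for gold and other similar commodities in volatile markets"
    )
    logger.info("Swarm analysis completed")
    # Format the output consistently before passing it downstream
    print(str(response))
except Exception as error:
    logger.error(f"Swarm run failed: {error}")
```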
## Common Issues and Solutions
!!! warning "Troubleshooting"
Common issues you might encounter:
1. **Memory Issues**
- *Problem*: VLLM consuming too much memory
- *Solution*: Adjust batch sizes and model parameters
2. **Agent Coordination**
- *Problem*: Agents providing conflicting information
- *Solution*: Implement consensus mechanisms or priority rules
3. **Performance**
- *Problem*: Slow response times
- *Solution*: Use proper batching and optimize model loading (see the sketch below)
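A rough sketch of those mitigations. The wrapper in this PR forwards extra keyword arguments to `vllm.LLM`, so standard engine options such as `gpu_memory_utilization` can be passed through, and smaller `batch_size` values reduce peak load (the specific values here are illustrative):
```python
from swarms.utils.vllm_wrapper import VLLMWrapper

# Cap GPU memory usage; extra kwargs are forwarded to the vLLM engine
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
    max_tokens=1024,              # shorter generations reduce memory pressure
    gpu_memory_utilization=0.8,   # vLLM engine option (forwarded via **kwargs)
)

# Process tasks in smaller batches to limit concurrent load
tasks = ["What is machine learning?", "Explain neural networks"]
results = vllm.batched_run(tasks, batch_size=2)
```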
## FAQ
??? question "Can I use different models for different agents?"
Yes, you can initialize multiple VLLM wrappers with different models for each agent. However, be mindful of memory usage.
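For example, a rough sketch of per-agent models (the model IDs are placeholders, and both checkpoints must fit in GPU memory at once):
```python
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLMWrapper

# Two wrappers backed by different checkpoints (placeholder model IDs)
small_llm = VLLMWrapper(model_name="meta-llama/Llama-2-7b-chat-hf")
large_llm = VLLMWrapper(model_name="meta-llama/Llama-2-13b-chat-hf")

fast_agent = Agent(
    agent_name="Screening-Agent",
    agent_description="Quick first-pass screening",
    system_prompt="You are a screening analyst.",
    max_loops=1,
    llm=small_llm,
)

deep_agent = Agent(
    agent_name="Deep-Dive-Agent",
    agent_description="Detailed follow-up analysis",
    system_prompt="You are a detailed research analyst.",
    max_loops=1,
    llm=large_llm,
)
```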
??? question "How many agents can run concurrently?"
The number depends on your hardware resources. Start with 3-5 agents and scale based on performance.
??? question "Can I customize agent communication patterns?"
Yes, you can modify the ConcurrentWorkflow class or create custom workflows for specific communication patterns.
## Advanced Configuration
!!! example "Extended Settings"
```python
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
    temperature=0.7,
    max_tokens=2048,
    top_p=0.95,
)
```
## Contributing
!!! info "Get Involved"
We welcome contributions! Here's how you can help:
1. Report bugs and issues
2. Submit feature requests
3. Contribute to documentation
4. Share example use cases
## Resources
!!! abstract "Additional Reading"
- [VLLM Documentation](https://docs.vllm.ai/en/latest/)

@ -0,0 +1,194 @@
# vLLM Integration Guide
!!! info "Overview"
vLLM is a high-performance and easy-to-use library for LLM inference and serving. This guide explains how to integrate vLLM with Swarms for efficient, production-grade language model deployment.
## Installation
!!! note "Prerequisites"
Before you begin, make sure you have Python 3.8+ installed on your system.
=== "pip"
```bash
pip install -U vllm swarms
```
=== "poetry"
```bash
poetry add vllm swarms
```
## Basic Usage
Here's a simple example of how to use vLLM with Swarms:
```python title="basic_usage.py"
from swarms.utils.vllm_wrapper import VLLMWrapper

# Initialize the vLLM wrapper
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
    temperature=0.7,
    max_tokens=4000,
)

# Run inference
response = vllm.run("What is the capital of France?")
print(response)
```
## VLLMWrapper Class
!!! abstract "Class Overview"
The `VLLMWrapper` class provides a convenient interface for working with vLLM models.
### Key Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `model_name` | str | Name of the model to use | "meta-llama/Llama-2-7b-chat-hf" |
| `system_prompt` | str | System prompt to use | None |
| `stream` | bool | Whether to stream the output | False |
| `temperature` | float | Sampling temperature | 0.5 |
| `max_tokens` | int | Maximum number of tokens to generate | 4000 |
### Example with Custom Parameters
```python title="custom_parameters.py"
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-13b-chat-hf",
    system_prompt="You are an expert in artificial intelligence.",
    temperature=0.8,
    max_tokens=2000,
)
```
## Integration with Agents
You can easily integrate vLLM with Swarms agents for more complex workflows:
```python title="agent_integration.py"
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLMWrapper

# Initialize vLLM
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

# Create an agent with vLLM
agent = Agent(
    agent_name="Research-Agent",
    agent_description="Expert in conducting research and analysis",
    system_prompt="""You are an expert research agent. Your tasks include:
    1. Analyzing complex topics
    2. Providing detailed summaries
    3. Making data-driven recommendations""",
    llm=vllm,
    max_loops=1,
)

# Run the agent
response = agent.run("Research the impact of AI on healthcare")
```
## Advanced Features
### Batch Processing
!!! tip "Performance Optimization"
Use batch processing for efficient handling of multiple tasks simultaneously.
```python title="batch_processing.py"
tasks = [
    "What is machine learning?",
    "Explain neural networks",
    "Describe deep learning",
]

results = vllm.batched_run(tasks, batch_size=3)
```
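Responses come back in the same order as the input tasks, so you can pair each task with its response afterwards:
```python
for task, response in zip(tasks, results):
    print(f"\nTask: {task}")
    print(f"Response: {response}")
```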
### Error Handling
!!! warning "Error Management"
Always implement proper error handling in production environments.
```python title="error_handling.py"
from loguru import logger

try:
    response = vllm.run("Complex task")
except Exception as error:
    logger.error(f"Error occurred: {error}")
```
## Best Practices
!!! success "Recommended Practices"
=== "Model Selection"
- Choose appropriate model sizes based on your requirements
- Consider the trade-off between model size and inference speed
=== "System Resources"
- Ensure sufficient GPU memory for your chosen model
- Monitor resource usage during batch processing
=== "Prompt Engineering"
- Use clear and specific system prompts
- Structure user prompts for optimal results
=== "Error Handling"
- Implement proper error handling and logging
- Set up monitoring for production deployments
=== "Performance"
- Use batch processing for multiple tasks
- Adjust max_tokens based on your use case
- Fine-tune temperature for optimal output quality (see the sketch below)
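A brief sketch tying those performance levers together (the values shown are illustrative, not tuned recommendations):
```python title="performance_tuning.py"
from swarms.utils.vllm_wrapper import VLLMWrapper

vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
    temperature=0.5,   # lower temperature for more deterministic output
    max_tokens=1000,   # cap generation length to what the use case needs
)

# Batch related prompts instead of issuing them one at a time
questions = [
    "Summarize vLLM in one sentence.",
    "List two benefits of batched inference.",
]
answers = vllm.batched_run(questions, batch_size=2)
```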
## Example: Multi-Agent System
Here's an example of creating a multi-agent system using vLLM:
```python title="multi_agent_system.py"
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper

# Initialize vLLM
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

# Create specialized agents
research_agent = Agent(
    agent_name="Research-Agent",
    agent_description="Expert in research",
    system_prompt="You are a research expert.",
    llm=vllm,
)

analysis_agent = Agent(
    agent_name="Analysis-Agent",
    agent_description="Expert in analysis",
    system_prompt="You are an analysis expert.",
    llm=vllm,
)

# Create a workflow
agents = [research_agent, analysis_agent]

workflow = ConcurrentWorkflow(
    name="Research-Analysis-Workflow",
    description="Comprehensive research and analysis workflow",
    agents=agents,
)

# Run the workflow
result = workflow.run("Analyze the impact of renewable energy")
```

@ -1,5 +1,6 @@
from swarms.utils.vllm_wrapper import VLLMWrapper
def main():
# Initialize the vLLM wrapper with a model
# Note: You'll need to have the model downloaded or specify a HuggingFace model ID
@ -31,14 +32,15 @@ def main():
tasks = [
"What is vLLM?",
"How does vLLM improve inference speed?",
"What are the main features of vLLM?"
"What are the main features of vLLM?",
]
responses = llm.batched_run(tasks, batch_size=2)
print("\nBatched responses:")
for task, response in zip(tasks, responses):
print(f"\nTask: {task}")
print(f"Response: {response}")
if __name__ == "__main__":
main()

@ -0,0 +1,68 @@
import os
from pathlib import Path


def concat_all_md_files(root_dir, output_filename="llm.txt"):
    """
    Recursively searches for all .md files in a directory and its subdirectories,
    then concatenates them into a single output file.

    Args:
        root_dir (str): Root directory to search for .md files
        output_filename (str): Name of the output file (default: llm.txt)

    Returns:
        str: Path to the created output file
    """
    try:
        root_dir = Path(root_dir).resolve()
        if not root_dir.is_dir():
            raise ValueError(f"Directory not found: {root_dir}")

        # Collect all .md files recursively
        md_files = []
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(".md"):
                    full_path = Path(root) / file
                    md_files.append(full_path)

        if not md_files:
            print(
                f"No .md files found in {root_dir} or its subdirectories"
            )
            return None

        # Create output file in root directory
        output_path = root_dir / output_filename
        with open(output_path, "w", encoding="utf-8") as outfile:
            for md_file in sorted(md_files):
                try:
                    # Get relative path for header
                    rel_path = md_file.relative_to(root_dir)

                    with open(
                        md_file, "r", encoding="utf-8"
                    ) as infile:
                        content = infile.read()

                    outfile.write(f"# File: {rel_path}\n\n")
                    outfile.write(content)
                    outfile.write(
                        "\n\n" + "-" * 50 + "\n\n"
                    )  # Separator
                except Exception as e:
                    print(f"Error processing {rel_path}: {str(e)}")
                    continue

        print(
            f"Created {output_path} with {len(md_files)} files merged"
        )
        return str(output_path)

    except Exception as e:
        print(f"Fatal error: {str(e)}")
        return None


if __name__ == "__main__":
    concat_all_md_files("docs")

@ -0,0 +1,20 @@
# math_server.py
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Math")


@mcp.tool()
def add(a: int, b: int) -> int:
    """Add two numbers"""
    return a + b


@mcp.tool()
def multiply(a: int, b: int) -> int:
    """Multiply two numbers"""
    return a * b


if __name__ == "__main__":
    mcp.run(transport="sse")

@ -22,4 +22,4 @@ pytest>=8.1.1
networkx
aiofiles
httpx
vllm>=0.2.0
# vllm>=0.2.0

@ -58,6 +58,12 @@ from swarms.utils.litellm_tokenizer import count_tokens
from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.str_to_dict import str_to_dict
from swarms.tools.mcp_integration import (
    batch_mcp_flow,
    mcp_flow_get_tool_schema,
    MCPServerSseParams,
)
# Utils
# Custom stopping condition
@ -352,6 +358,7 @@ class Agent:
        role: agent_roles = "worker",
        no_print: bool = False,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        mcp_servers: List[MCPServerSseParams] = [],
        *args,
        **kwargs,
    ):
@ -471,6 +478,7 @@ class Agent:
        self.role = role
        self.no_print = no_print
        self.tools_list_dictionary = tools_list_dictionary
        self.mcp_servers = mcp_servers

        if (
            self.agent_name is not None
@ -584,6 +592,12 @@ class Agent:
        if self.llm is None:
            self.llm = self.llm_handling()

        if (
            self.tools_list_dictionary is None
            and self.mcp_servers is not None
        ):
            self.tools_list_dictionary = self.mcp_tool_handling()

    def llm_handling(self):
        from swarms.utils.litellm_wrapper import LiteLLM
@ -631,6 +645,69 @@ class Agent:
logger.error(f"Error in llm_handling: {e}")
return None
def mcp_execution_flow(self, response: any):
"""
Executes the MCP (Model Context Protocol) flow based on the provided response.
This method takes a response, converts it from a string to a dictionary format,
and checks for the presence of a tool name or a name in the response. If either
is found, it retrieves the tool name and proceeds to call the batch_mcp_flow
function to execute the corresponding tool actions.
Args:
response (any): The response to be processed, which can be in string format
that represents a dictionary.
Returns:
The output from the batch_mcp_flow function, which contains the results of
the tool execution. If an error occurs during processing, it logs the error
and returns None.
Raises:
Exception: Logs any exceptions that occur during the execution flow.
"""
try:
response = str_to_dict(response)
tool_output = batch_mcp_flow(
self.mcp_servers,
function_call=response,
)
return tool_output
except Exception as e:
logger.error(f"Error in mcp_execution_flow: {e}")
return None
def mcp_tool_handling(self):
"""
Handles the retrieval of tool schemas from the MCP servers.
This method iterates over the list of MCP servers, retrieves the tool schema
for each server using the mcp_flow_get_tool_schema function, and compiles
these schemas into a list. The resulting list is stored in the
tools_list_dictionary attribute.
Returns:
list: A list of tool schemas retrieved from the MCP servers. If an error
occurs during the retrieval process, it logs the error and returns None.
Raises:
Exception: Logs any exceptions that occur during the tool handling process.
"""
try:
self.tools_list_dictionary = []
for mcp_server in self.mcp_servers:
tool_schema = mcp_flow_get_tool_schema(mcp_server)
self.tools_list_dictionary.append(tool_schema)
print(self.tools_list_dictionary)
return self.tools_list_dictionary
except Exception as e:
logger.error(f"Error in mcp_tool_handling: {e}")
return None
def setup_config(self):
# The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops is True:

@ -1,554 +1,392 @@
from contextlib import AsyncExitStack
from types import TracebackType
from typing import (
Any,
Callable,
Coroutine,
List,
Literal,
Optional,
TypedDict,
cast,
)
from __future__ import annotations
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
PromptMessage,
TextContent,
from typing import Any, List
from loguru import logger
import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Literal
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
)
from mcp.types import (
from mcp import (
ClientSession,
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict
from swarms.utils.any_to_str import any_to_str
def convert_mcp_prompt_message_to_message(
message: PromptMessage,
) -> str:
"""Convert an MCP prompt message to a string message.
Args:
message: MCP prompt message to convert
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""
Returns:
a string message
"""
if message.content.type == "text":
if message.role == "user":
return str(message.content.text)
elif message.role == "assistant":
return str(
message.content.text
) # Fixed attribute name from str to text
else:
raise ValueError(
f"Unsupported prompt message role: {message.role}"
)
@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
opening a network connection. The server is expected to remain connected until
`cleanup()` is called.
"""
pass
@property
@abc.abstractmethod
def name(self) -> str:
"""A readable name for the server."""
pass
@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass
raise ValueError(
f"Unsupported prompt message content type: {message.content.type}"
)
@abc.abstractmethod
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
pass
async def load_mcp_prompt(
session: ClientSession,
name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Load MCP prompt and convert to messages."""
response = await session.get_prompt(name, arguments)
return [
convert_mcp_prompt_message_to_message(message)
for message in response.messages
]
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
DEFAULT_HTTP_TIMEOUT = 5
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
@abc.abstractmethod
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
class StdioConnection(TypedDict):
transport: Literal["stdio"]
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if (
self.cache_tools_list
and not self._cache_dirty
and self._tools_list
):
return self._tools_list
command: str
"""The executable to run to start the server."""
# Reset the cache dirty to False
self._cache_dirty = False
args: list[str]
"""Command line arguments to pass to the executable."""
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
env: dict[str, str] | None
"""The environment to use when spawning the process."""
async def call_tool(
self, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get(
"name"
)
encoding: str
"""The text encoding used when sending/receiving messages to the server."""
if not tool_name:
raise Exception("No tool name found in arguments")
encoding_error_handler: Literal["strict", "ignore", "replace"]
"""
The text encoding error handler.
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""
return await self.session.call_tool(tool_name, arguments)
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
class SSEConnection(TypedDict):
transport: Literal["sse"]
url: str
"""The URL of the SSE endpoint to connect to."""
class MCPServerStdioParams(TypedDict):
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""
headers: dict[str, Any] | None
"""HTTP headers to send to the SSE endpoint"""
command: str
"""The executable to run to start the server. For example, `python` or `node`."""
timeout: float
"""HTTP timeout"""
args: NotRequired[list[str]]
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""
sse_read_timeout: float
"""SSE read timeout"""
env: NotRequired[dict[str, str]]
"""The environment variables to set for the server. ."""
cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""
NonTextContent = ImageContent | EmbeddedResource
encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
encoding_error_handler: NotRequired[
Literal["strict", "ignore", "replace"]
]
"""The text encoding error handler. Defaults to `strict`.
def _convert_call_tool_result(
call_tool_result: CallToolResult,
) -> tuple[str | list[str], list[NonTextContent] | None]:
text_contents: list[TextContent] = []
non_text_contents = []
for content in call_tool_result.content:
if isinstance(content, TextContent):
text_contents.append(content)
else:
non_text_contents.append(content)
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values.
"""
tool_content: str | list[str] = [
content.text for content in text_contents
]
if len(text_contents) == 1:
tool_content = tool_content[0]
if call_tool_result.isError:
raise ValueError("Error calling tool")
class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""
return tool_content, non_text_contents or None
def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the stdio transport.
Args:
params: The params that configure the server. This includes the command to run to
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get(
"encoding_error_handler", "strict"
),
)
def convert_mcp_tool_to_function(
session: ClientSession,
tool: MCPTool,
) -> Callable[
...,
Coroutine[
Any, Any, tuple[str | list[str], list[NonTextContent] | None]
],
]:
"""Convert an MCP tool to a callable function.
self._name = name or f"stdio: {self.params.command}"
NOTE: this tool can be executed only in a context of an active MCP client session.
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)
Args:
session: MCP client session
tool: MCP tool to convert
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
Returns:
a callable function
"""
async def call_tool(
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
"""Execute the tool with the given arguments."""
call_tool_result = await session.call_tool(
tool.name, arguments
)
return _convert_call_tool_result(call_tool_result)
class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`."""
# Add metadata as attributes to the function
call_tool.__name__ = tool.name
call_tool.__doc__ = tool.description or ""
call_tool.schema = tool.inputSchema
url: str
"""The URL of the server."""
return call_tool
headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""
timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""
async def load_mcp_tools(session: ClientSession) -> list[Callable]:
"""Load all available MCP tools and convert them to callable functions."""
tools = await session.list_tools()
return [
convert_mcp_tool_to_function(session, tool)
for tool in tools.tools
]
sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
class MultiServerMCPClient:
"""Client for connecting to multiple MCP servers and loading tools from them."""
class MCPServerSse(_MCPServerWithClientSession):
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""
def __init__(
self,
connections: dict[
str, StdioConnection | SSEConnection
] = None,
) -> None:
"""Initialize a MultiServerMCPClient with MCP servers connections.
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.
Args:
connections: A dictionary mapping server names to connection configurations.
Each configuration can be either a StdioConnection or SSEConnection.
If None, no initial connections are established.
Example:
```python
async with MultiServerMCPClient(
{
"math": {
"command": "python",
# Make sure to update to the full absolute path to your math_server.py file
"args": ["/path/to/math_server.py"],
"transport": "stdio",
},
"weather": {
# make sure you start your weather server on port 8000
"url": "http://localhost:8000/sse",
"transport": "sse",
}
}
) as client:
all_tools = client.get_tools()
...
```
params: The params that configure the server. This includes the URL of the server,
the headers to send to the server, the timeout for the HTTP request, and the
timeout for the SSE connection.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
URL.
"""
self.connections = connections
self.exit_stack = AsyncExitStack()
self.sessions: dict[str, ClientSession] = {}
self.server_name_to_tools: dict[str, list[Callable]] = {}
super().__init__(cache_tools_list)
async def _initialize_session_and_load_tools(
self, server_name: str, session: ClientSession
) -> None:
"""Initialize a session and load tools from it.
self.params = params
self._name = name or f"sse: {self.params['url']}"
Args:
server_name: Name to identify this server connection
session: The ClientSession to initialize
"""
# Initialize the session
await session.initialize()
self.sessions[server_name] = session
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get(
"sse_read_timeout", 60 * 5
),
)
# Load tools from this server
server_tools = await load_mcp_tools(session)
self.server_name_to_tools[server_name] = server_tools
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
async def connect_to_server(
self,
server_name: str,
*,
transport: Literal["stdio", "sse"] = "stdio",
**kwargs,
) -> None:
"""Connect to an MCP server using either stdio or SSE.
This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
based on the provided transport parameter.
def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
Args:
server_name: Name to identify this server connection
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
**kwargs: Additional arguments to pass to the specific connection method
# Connect the server
asyncio.run(server.connect())
Raises:
ValueError: If transport is not recognized
ValueError: If required parameters for the specified transport are missing
"""
if transport == "sse":
if "url" not in kwargs:
raise ValueError(
"'url' parameter is required for SSE connection"
)
await self.connect_to_server_via_sse(
server_name,
url=kwargs["url"],
headers=kwargs.get("headers"),
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
sse_read_timeout=kwargs.get(
"sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
),
)
elif transport == "stdio":
if "command" not in kwargs:
raise ValueError(
"'command' parameter is required for stdio connection"
)
if "args" not in kwargs:
raise ValueError(
"'args' parameter is required for stdio connection"
)
await self.connect_to_server_via_stdio(
server_name,
command=kwargs["command"],
args=kwargs["args"],
env=kwargs.get("env"),
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
encoding_error_handler=kwargs.get(
"encoding_error_handler",
DEFAULT_ENCODING_ERROR_HANDLER,
),
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
# Return the server
output = asyncio.run(server.list_tools())
async def connect_to_server_via_stdio(
self,
server_name: str,
*,
command: str,
args: list[str],
env: dict[str, str] | None = None,
encoding: str = DEFAULT_ENCODING,
encoding_error_handler: Literal[
"strict", "ignore", "replace"
] = DEFAULT_ENCODING_ERROR_HANDLER,
) -> None:
"""Connect to a specific MCP server using stdio
# Cleanup the server
asyncio.run(server.cleanup())
Args:
server_name: Name to identify this server connection
command: Command to execute
args: Arguments for the command
env: Environment variables for the command
encoding: Character encoding
encoding_error_handler: How to handle encoding errors
"""
server_params = StdioServerParameters(
command=command,
args=args,
env=env,
encoding=encoding,
encoding_error_handler=encoding_error_handler,
)
return output.model_dump()
# Create and store the connection
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
read, write = stdio_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
await self._initialize_session_and_load_tools(
server_name, session
)
def mcp_flow(
params: MCPServerSseParams,
function_call: dict[str, Any],
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
async def connect_to_server_via_sse(
self,
server_name: str,
*,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = DEFAULT_HTTP_TIMEOUT,
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
) -> None:
"""Connect to a specific MCP server using SSE
# Connect the server
asyncio.run(server.connect())
Args:
server_name: Name to identify this server connection
url: URL of the SSE server
headers: HTTP headers to send to the SSE endpoint
timeout: HTTP timeout
sse_read_timeout: SSE read timeout
"""
# Create and store the connection
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout, sse_read_timeout)
)
read, write = sse_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
# Return the server
output = asyncio.run(server.call_tool(function_call))
await self._initialize_session_and_load_tools(
server_name, session
)
output = output.model_dump()
def get_tools(self) -> list[Callable]:
"""Get a list of all tools from all connected servers."""
all_tools: list[Callable] = []
for server_tools in self.server_name_to_tools.values():
all_tools.extend(server_tools)
return all_tools
# Cleanup the server
asyncio.run(server.cleanup())
async def get_prompt(
self,
server_name: str,
prompt_name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Get a prompt from a given MCP server."""
session = self.sessions[server_name]
return await load_mcp_prompt(session, prompt_name, arguments)
async def __aenter__(self) -> "MultiServerMCPClient":
try:
connections = self.connections or {}
for server_name, connection in connections.items():
connection_dict = connection.copy()
transport = connection_dict.pop("transport")
if transport == "stdio":
await self.connect_to_server_via_stdio(
server_name, **connection_dict
)
elif transport == "sse":
await self.connect_to_server_via_sse(
server_name, **connection_dict
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
return self
except Exception:
await self.exit_stack.aclose()
raise
return any_to_str(output)
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.exit_stack.aclose()
# #!/usr/bin/env python3
# import asyncio
# import os
# import json
# from typing import List, Any, Callable
# # # Import our MCP client module
# # from mcp_client import MultiServerMCPClient
# async def main():
# """Test script for demonstrating MCP client usage."""
# print("Starting MCP Client test...")
# # Create a connection to multiple MCP servers
# # You'll need to update these paths to match your setup
# async with MultiServerMCPClient(
# {
# "math": {
# "transport": "stdio",
# "command": "python",
# "args": ["/path/to/math_server.py"],
# "env": {"DEBUG": "1"},
# },
# "search": {
# "transport": "sse",
# "url": "http://localhost:8000/sse",
# "headers": {
# "Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
# },
# },
# }
# ) as client:
# # Get all available tools
# tools = client.get_tools()
# print(f"Found {len(tools)} tools across all servers")
# # Print tool information
# for i, tool in enumerate(tools):
# print(f"\nTool {i+1}: {tool.__name__}")
# print(f" Description: {tool.__doc__}")
# if hasattr(tool, "schema") and tool.schema:
# print(
# f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
# )
# # Example: Use a specific tool if available
# calculator_tool = next(
# (t for t in tools if t.__name__ == "calculator"), None
# )
# if calculator_tool:
# print("\n\nTesting calculator tool:")
# try:
# # Call the tool as an async function
# result, artifacts = await calculator_tool(
# expression="2 + 2 * 3"
# )
# print(f" Calculator result: {result}")
# if artifacts:
# print(
# f" With {len(artifacts)} additional artifacts"
# )
# except Exception as e:
# print(f" Error using calculator: {e}")
# # Example: Load a prompt from a server
# try:
# print("\n\nTesting prompt loading:")
# prompt_messages = await client.get_prompt(
# "math",
# "calculation_introduction",
# {"user_name": "Test User"},
# )
# print(
# f" Loaded prompt with {len(prompt_messages)} messages:"
# )
# for i, msg in enumerate(prompt_messages):
# print(f" Message {i+1}: {msg[:50]}...")
# except Exception as e:
# print(f" Error loading prompt: {e}")
# async def create_custom_tool():
# """Example of creating a custom tool function."""
# # Define a tool function with metadata
# async def add_numbers(a: float, b: float) -> tuple[str, None]:
# """Add two numbers together."""
# result = a + b
# return f"The sum of {a} and {b} is {result}", None
# # Add metadata to the function
# add_numbers.__name__ = "add_numbers"
# add_numbers.__doc__ = (
# "Add two numbers together and return the result."
# )
# add_numbers.schema = {
# "type": "object",
# "properties": {
# "a": {"type": "number", "description": "First number"},
# "b": {"type": "number", "description": "Second number"},
# },
# "required": ["a", "b"],
# }
# # Use the tool
# result, _ = await add_numbers(a=5, b=7)
# print(f"\nCustom tool result: {result}")
# if __name__ == "__main__":
# # Run both examples
# loop = asyncio.get_event_loop()
# loop.run_until_complete(main())
# loop.run_until_complete(create_custom_tool())
def batch_mcp_flow(
params: List[MCPServerSseParams],
function_call: List[dict[str, Any]] = [],
) -> MCPServer:
output_list = []
for param in params:
output = mcp_flow(param, function_call)
output_list.append(output)
return output_list

@ -6,11 +6,15 @@ try:
except ImportError:
import subprocess
import sys
print("Installing vllm")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-U", "vllm"]
)
print("vllm installed")
from vllm import LLM, SamplingParams
class VLLMWrapper:
"""
A wrapper class for vLLM that provides a similar interface to LiteLLM.
@ -54,7 +58,7 @@ class VLLMWrapper:
self.tools_list_dictionary = tools_list_dictionary
self.tool_choice = tool_choice
self.parallel_tool_calls = parallel_tool_calls
# Initialize vLLM
self.llm = LLM(model=model_name, **kwargs)
self.sampling_params = SamplingParams(
@ -90,12 +94,12 @@ class VLLMWrapper:
"""
try:
prompt = self._prepare_prompt(task)
outputs = self.llm.generate(prompt, self.sampling_params)
response = outputs[0].outputs[0].text.strip()
return response
except Exception as error:
logger.error(f"Error in VLLMWrapper: {error}")
raise error
@ -114,7 +118,9 @@ class VLLMWrapper:
"""
return self.run(task, *args, **kwargs)
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
def batched_run(
self, tasks: List[str], batch_size: int = 10
) -> List[str]:
"""
Run the model for multiple tasks in batches.
@ -125,14 +131,16 @@ class VLLMWrapper:
Returns:
List[str]: List of model responses.
"""
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
logger.info(
f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
)
results = []
for i in range(0, len(tasks), batch_size):
batch = tasks[i:i + batch_size]
batch = tasks[i : i + batch_size]
for task in batch:
logger.info(f"Running task: {task}")
results.append(self.run(task))
logger.info("Completed all tasks.")
return results

@ -0,0 +1,212 @@
from swarms import Agent, ConcurrentWorkflow
from swarms.utils.vllm_wrapper import VLLMWrapper
from dotenv import load_dotenv
load_dotenv()
# Initialize the VLLM wrapper
vllm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf",
system_prompt="You are a helpful assistant.",
)
# Technical Analysis Agent
technical_analyst = Agent(
agent_name="Technical-Analysis-Agent",
agent_description="Expert in technical analysis and chart patterns",
system_prompt="""You are an expert Technical Analysis Agent specializing in market technicals and chart patterns. Your responsibilities include:
1. PRICE ACTION ANALYSIS
- Identify key support and resistance levels
- Analyze price trends and momentum
- Detect chart patterns (e.g., head & shoulders, triangles, flags)
- Evaluate volume patterns and their implications
2. TECHNICAL INDICATORS
- Calculate and interpret moving averages (SMA, EMA)
- Analyze momentum indicators (RSI, MACD, Stochastic)
- Evaluate volume indicators (OBV, Volume Profile)
- Monitor volatility indicators (Bollinger Bands, ATR)
3. TRADING SIGNALS
- Generate clear buy/sell signals based on technical criteria
- Identify potential entry and exit points
- Set appropriate stop-loss and take-profit levels
- Calculate position sizing recommendations
4. RISK MANAGEMENT
- Assess market volatility and trend strength
- Identify potential reversal points
- Calculate risk/reward ratios for trades
- Suggest position sizing based on risk parameters
Your analysis should be data-driven, precise, and actionable. Always include specific price levels, time frames, and risk parameters in your recommendations.""",
max_loops=1,
llm=vllm,
)
# Fundamental Analysis Agent
fundamental_analyst = Agent(
agent_name="Fundamental-Analysis-Agent",
agent_description="Expert in company fundamentals and valuation",
system_prompt="""You are an expert Fundamental Analysis Agent specializing in company valuation and financial metrics. Your core responsibilities include:
1. FINANCIAL STATEMENT ANALYSIS
- Analyze income statements, balance sheets, and cash flow statements
- Calculate and interpret key financial ratios
- Evaluate revenue growth and profit margins
- Assess company's debt levels and cash position
2. VALUATION METRICS
- Calculate fair value using multiple valuation methods:
* Discounted Cash Flow (DCF)
* Price-to-Earnings (P/E)
* Price-to-Book (P/B)
* Enterprise Value/EBITDA
- Compare valuations against industry peers
3. BUSINESS MODEL ASSESSMENT
- Evaluate competitive advantages and market position
- Analyze industry dynamics and market share
- Assess management quality and corporate governance
- Identify potential risks and growth opportunities
4. ECONOMIC CONTEXT
- Consider macroeconomic factors affecting the company
- Analyze industry cycles and trends
- Evaluate regulatory environment and compliance
- Assess global market conditions
Your analysis should be comprehensive, focusing on both quantitative metrics and qualitative factors that impact long-term value.""",
max_loops=1,
llm=vllm,
)
# Market Sentiment Agent
sentiment_analyst = Agent(
agent_name="Market-Sentiment-Agent",
agent_description="Expert in market psychology and sentiment analysis",
system_prompt="""You are an expert Market Sentiment Agent specializing in analyzing market psychology and investor behavior. Your key responsibilities include:
1. SENTIMENT INDICATORS
- Monitor and interpret market sentiment indicators:
* VIX (Fear Index)
* Put/Call Ratio
* Market Breadth
* Investor Surveys
- Track institutional vs retail investor behavior
2. NEWS AND SOCIAL MEDIA ANALYSIS
- Analyze news flow and media sentiment
- Monitor social media trends and discussions
- Track analyst recommendations and changes
- Evaluate corporate insider trading patterns
3. MARKET POSITIONING
- Assess hedge fund positioning and exposure
- Monitor short interest and short squeeze potential
- Track fund flows and asset allocation trends
- Analyze options market sentiment
4. CONTRARIAN SIGNALS
- Identify extreme sentiment readings
- Detect potential market turning points
- Analyze historical sentiment patterns
- Provide contrarian trading opportunities
Your analysis should combine quantitative sentiment metrics with qualitative assessment of market psychology and crowd behavior.""",
max_loops=1,
llm=vllm,
)
# Quantitative Strategy Agent
quant_analyst = Agent(
agent_name="Quantitative-Strategy-Agent",
agent_description="Expert in quantitative analysis and algorithmic strategies",
system_prompt="""You are an expert Quantitative Strategy Agent specializing in data-driven investment strategies. Your primary responsibilities include:
1. FACTOR ANALYSIS
- Analyze and monitor factor performance:
* Value
* Momentum
* Quality
* Size
* Low Volatility
- Calculate factor exposures and correlations
2. STATISTICAL ANALYSIS
- Perform statistical arbitrage analysis
- Calculate and monitor pair trading opportunities
- Analyze market anomalies and inefficiencies
- Develop mean reversion strategies
3. RISK MODELING
- Build and maintain risk models
- Calculate portfolio optimization metrics
- Monitor correlation matrices
- Analyze tail risk and stress scenarios
4. ALGORITHMIC STRATEGIES
- Develop systematic trading strategies
- Backtest and validate trading algorithms
- Monitor strategy performance metrics
- Optimize execution algorithms
Your analysis should be purely quantitative, based on statistical evidence and mathematical models rather than subjective opinions.""",
max_loops=1,
llm=vllm,
)
# Portfolio Strategy Agent
portfolio_strategist = Agent(
agent_name="Portfolio-Strategy-Agent",
agent_description="Expert in portfolio management and asset allocation",
system_prompt="""You are an expert Portfolio Strategy Agent specializing in portfolio construction and management. Your core responsibilities include:
1. ASSET ALLOCATION
- Develop strategic asset allocation frameworks
- Recommend tactical asset allocation shifts
- Optimize portfolio weightings
- Balance risk and return objectives
2. PORTFOLIO ANALYSIS
- Calculate portfolio risk metrics
- Monitor sector and factor exposures
- Analyze portfolio correlation matrix
- Track performance attribution
3. RISK MANAGEMENT
- Implement portfolio hedging strategies
- Monitor and adjust position sizing
- Set stop-loss and rebalancing rules
- Develop drawdown protection strategies
4. PORTFOLIO OPTIMIZATION
- Calculate efficient frontier analysis
- Optimize for various objectives:
* Maximum Sharpe Ratio
* Minimum Volatility
* Maximum Diversification
- Consider transaction costs and taxes
Your recommendations should focus on portfolio-level decisions that optimize risk-adjusted returns while meeting specific investment objectives.""",
max_loops=1,
llm=vllm,
)
# Create a list of all agents
stock_analysis_agents = [
technical_analyst,
fundamental_analyst,
sentiment_analyst,
quant_analyst,
portfolio_strategist
]
swarm = ConcurrentWorkflow(
name="Stock-Analysis-Swarm",
description="A swarm of agents that analyze stocks and provide a comprehensive analysis of the current trends and opportunities.",
agents=stock_analysis_agents,
)
swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets")