parent a673ba2d71
commit bb7e18f654
@@ -1,387 +0,0 @@
from typing import List, Dict, Optional, Union, Any
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer
import faiss
import pickle
import time
from loguru import logger
from concurrent.futures import ThreadPoolExecutor
import threading
import uuid
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""Represents a document in the HQD-RAG system.
|
||||
|
||||
Attributes:
|
||||
id (str): Unique identifier for the document
|
||||
content (str): Raw text content of the document
|
||||
embedding (Optional[np.ndarray]): Quantum-inspired embedding vector
|
||||
cluster_id (Optional[int]): ID of the cluster this document belongs to
|
||||
"""
|
||||
id: str
|
||||
content: str
|
||||
embedding: Optional[np.ndarray] = None
|
||||
cluster_id: Optional[int] = None
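
# Illustrative construction of the dataclass above (embedding and cluster_id are
# filled in later by HQDRAG.add and _update_clusters):
#
#     doc = Document(id="doc-1", content="example text")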
|
||||
|
||||
class HQDRAG:
|
||||
"""
|
||||
Hierarchical Quantum-Inspired Distributed RAG (HQD-RAG) System
|
||||
|
||||
A production-grade implementation of the HQD-RAG algorithm for ultra-fast
|
||||
and reliable document retrieval. Uses quantum-inspired embeddings and
|
||||
hierarchical clustering for efficient search.
|
||||
|
||||
Attributes:
|
||||
embedding_dim (int): Dimension of the quantum-inspired embeddings
|
||||
num_clusters (int): Number of hierarchical clusters
|
||||
similarity_threshold (float): Threshold for quantum similarity matching
|
||||
reliability_threshold (float): Threshold for reliability verification
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
num_clusters: int = 128,
|
||||
similarity_threshold: float = 0.75,
|
||||
reliability_threshold: float = 0.85,
|
||||
model_name: str = "all-MiniLM-L6-v2"
|
||||
):
|
||||
"""Initialize the HQD-RAG system.
|
||||
|
||||
Args:
|
||||
embedding_dim: Dimension of document embeddings
|
||||
num_clusters: Number of clusters for hierarchical organization
|
||||
similarity_threshold: Minimum similarity score for retrieval
|
||||
reliability_threshold: Minimum reliability score for verification
|
||||
model_name: Name of the sentence transformer model to use
|
||||
"""
|
||||
logger.info(f"Initializing HQD-RAG with {embedding_dim} dimensions")
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_clusters = num_clusters
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.reliability_threshold = reliability_threshold
|
||||
|
||||
# Initialize components
|
||||
self.documents: Dict[str, Document] = {}
|
||||
self.encoder = SentenceTransformer(model_name)
# Use the encoder's actual output dimension (e.g. 384 for all-MiniLM-L6-v2) so the
# FAISS index matches the vectors produced by _compute_quantum_embedding below.
self.embedding_dim = self.encoder.get_sentence_embedding_dimension() or embedding_dim
self.index = faiss.IndexFlatIP(self.embedding_dim)  # Inner product index
|
||||
self.clustering = AgglomerativeClustering(
|
||||
n_clusters=num_clusters,
|
||||
metric='euclidean',
|
||||
linkage='ward'
|
||||
)
|
||||
|
||||
# Thread safety
|
||||
self._lock = threading.Lock()
|
||||
self._executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
logger.info("HQD-RAG system initialized successfully")
|
||||
|
||||
def _compute_quantum_embedding(self, text: str) -> np.ndarray:
|
||||
"""Compute quantum-inspired embedding for text.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
|
||||
Returns:
|
||||
Quantum-inspired embedding vector
|
||||
"""
|
||||
# Get base embedding from the sentence transformer
base_embedding = self.encoder.encode([text])[0]

# Apply quantum-inspired transformation: simulate superposition by modulating
# the vector with random phase components, then take the real part so the
# result is a real-valued float32 vector that FAISS can index
phase = np.exp(2j * np.pi * np.random.random(base_embedding.shape[0]))
quantum_embedding = np.real(base_embedding * phase).astype(np.float32)

# Normalize to unit length
return quantum_embedding / np.linalg.norm(quantum_embedding)
|
||||
|
||||
def _verify_reliability(self, doc: Document, query_embedding: np.ndarray) -> float:
|
||||
"""Verify the reliability of a document match.
|
||||
|
||||
Args:
|
||||
doc: Document to verify
|
||||
query_embedding: Query embedding vector
|
||||
|
||||
Returns:
|
||||
Reliability score between 0 and 1
|
||||
"""
|
||||
if doc.embedding is None:
|
||||
return 0.0
|
||||
|
||||
# Compute consistency score
|
||||
consistency = np.abs(np.dot(doc.embedding, query_embedding))
|
||||
|
||||
# Add quantum noise resistance check
|
||||
noise = np.random.normal(0, 0.1, self.embedding_dim)
|
||||
noisy_query = query_embedding + noise
|
||||
noisy_query = noisy_query / np.linalg.norm(noisy_query)
|
||||
noise_resistance = np.abs(np.dot(doc.embedding, noisy_query))
|
||||
|
||||
return (consistency + noise_resistance) / 2
|
||||
|
||||
def add(self, content: str, doc_id: Optional[str] = None) -> str:
|
||||
"""Add a document to the system.
|
||||
|
||||
Args:
|
||||
content: Document text content
|
||||
doc_id: Optional custom document ID
|
||||
|
||||
Returns:
|
||||
Document ID
|
||||
"""
|
||||
doc_id = doc_id or str(uuid.uuid4())
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
# Compute embedding
|
||||
embedding = self._compute_quantum_embedding(content)
|
||||
|
||||
# Create document
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
content=content,
|
||||
embedding=embedding
|
||||
)
|
||||
|
||||
# Add to storage
|
||||
self.documents[doc_id] = doc
|
||||
self.index.add(embedding.reshape(1, -1))
|
||||
|
||||
# Update clustering
|
||||
self._update_clusters()
|
||||
|
||||
logger.info(f"Successfully added document {doc_id}")
|
||||
return doc_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding document: {str(e)}")
|
||||
raise
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
return_scores: bool = False
|
||||
) -> Union[List[str], List[tuple[str, float]]]:
|
||||
"""Query the system for relevant documents.
|
||||
|
||||
Args:
|
||||
query: Query text
|
||||
k: Number of results to return
|
||||
return_scores: Whether to return similarity scores
|
||||
|
||||
Returns:
|
||||
List of document IDs or (document ID, score) tuples
|
||||
"""
|
||||
try:
|
||||
# Compute query embedding
|
||||
query_embedding = self._compute_quantum_embedding(query)
|
||||
|
||||
# Search index
|
||||
scores, indices = self.index.search(
|
||||
query_embedding.reshape(1, -1),
|
||||
k * 2 # Get extra results for reliability filtering
|
||||
)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
# Get document
|
||||
doc_id = list(self.documents.keys())[idx]
|
||||
doc = self.documents[doc_id]
|
||||
|
||||
# Verify reliability
|
||||
reliability = self._verify_reliability(doc, query_embedding)
|
||||
|
||||
if reliability >= self.reliability_threshold:
|
||||
results.append((doc_id, float(score)))
|
||||
|
||||
if len(results) >= k:
|
||||
break
|
||||
|
||||
logger.info(f"Query returned {len(results)} results")
|
||||
|
||||
if return_scores:
|
||||
return results
|
||||
return [doc_id for doc_id, _ in results]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing query: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, doc_id: str, new_content: str) -> None:
|
||||
"""Update an existing document.
|
||||
|
||||
Args:
|
||||
doc_id: ID of document to update
|
||||
new_content: New document content
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
if doc_id not in self.documents:
|
||||
raise KeyError(f"Document {doc_id} not found")
|
||||
|
||||
# Remove old embedding
|
||||
old_doc = self.documents[doc_id]
|
||||
if old_doc.embedding is not None:
|
||||
self.index.remove_ids(np.array([list(self.documents.keys()).index(doc_id)]))
|
||||
|
||||
# Compute new embedding
|
||||
new_embedding = self._compute_quantum_embedding(new_content)
|
||||
|
||||
# Update document
|
||||
self.documents[doc_id] = Document(
|
||||
id=doc_id,
|
||||
content=new_content,
|
||||
embedding=new_embedding
|
||||
)
|
||||
|
||||
# Add new embedding
|
||||
self.index.add(new_embedding.reshape(1, -1))
|
||||
|
||||
# Update clustering
|
||||
self._update_clusters()
|
||||
|
||||
logger.info(f"Successfully updated document {doc_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating document: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Delete a document from the system.
|
||||
|
||||
Args:
|
||||
doc_id: ID of document to delete
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
if doc_id not in self.documents:
|
||||
raise KeyError(f"Document {doc_id} not found")
|
||||
|
||||
# Remove from index
|
||||
idx = list(self.documents.keys()).index(doc_id)
|
||||
self.index.remove_ids(np.array([idx]))
|
||||
|
||||
# Remove from storage
|
||||
del self.documents[doc_id]
|
||||
|
||||
# Update clustering
|
||||
self._update_clusters()
|
||||
|
||||
logger.info(f"Successfully deleted document {doc_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting document: {str(e)}")
|
||||
raise
|
||||
|
||||
def _update_clusters(self) -> None:
|
||||
"""Update hierarchical document clusters."""
|
||||
if len(self.documents) < 2:
|
||||
return
|
||||
|
||||
# Get all embeddings
|
||||
embeddings = np.vstack([
|
||||
doc.embedding for doc in self.documents.values()
|
||||
if doc.embedding is not None
|
||||
])
|
||||
|
||||
# Update clustering; AgglomerativeClustering requires n_clusters <= n_samples,
# so cap the cluster count while the collection is still small
self.clustering.set_params(n_clusters=min(self.num_clusters, len(embeddings)))
clusters = self.clustering.fit_predict(embeddings)
|
||||
|
||||
# Assign cluster IDs
|
||||
for doc, cluster_id in zip(self.documents.values(), clusters):
|
||||
doc.cluster_id = int(cluster_id)
|
||||
|
||||
def save(self, path: Union[str, Path]) -> None:
|
||||
"""Save the system state to disk.
|
||||
|
||||
Args:
|
||||
path: Path to save directory
|
||||
"""
|
||||
path = Path(path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Save documents
|
||||
with open(path / "documents.pkl", "wb") as f:
|
||||
pickle.dump(self.documents, f)
|
||||
|
||||
# Save index
|
||||
faiss.write_index(self.index, str(path / "index.faiss"))
|
||||
|
||||
logger.info(f"Successfully saved system state to {path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving system state: {str(e)}")
|
||||
raise
|
||||
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load the system state from disk.
|
||||
|
||||
Args:
|
||||
path: Path to the directory containing a previously saved state
|
||||
"""
|
||||
path = Path(path)
|
||||
|
||||
try:
|
||||
# Load documents
|
||||
with open(path / "documents.pkl", "rb") as f:
|
||||
self.documents = pickle.load(f)
|
||||
|
||||
# Load index
|
||||
self.index = faiss.read_index(str(path / "index.faiss"))
|
||||
|
||||
logger.info(f"Successfully loaded system state from {path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading system state: {str(e)}")
|
||||
raise
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logger.add(
|
||||
"hqd_rag.log",
|
||||
rotation="1 day",
|
||||
retention="1 week",
|
||||
level="INFO"
|
||||
)
|
||||
|
||||
# Initialize system
|
||||
rag = HQDRAG()
|
||||
|
||||
# Add some documents
|
||||
doc_ids = []
|
||||
docs = [
|
||||
"The quick brown fox jumps over the lazy dog",
|
||||
"Machine learning is a subset of artificial intelligence",
|
||||
"Python is a popular programming language"
|
||||
]
|
||||
|
||||
for doc in docs:
|
||||
doc_id = rag.add(doc)
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
# Query
|
||||
results = rag.query("What is machine learning?", return_scores=True)
|
||||
print("Query results:", results)
|
||||
|
||||
# # Update a document
|
||||
# rag.update(doc_ids[0], "The fast brown fox jumps over the sleepy dog")
|
||||
|
||||
# # Delete a document
|
||||
# rag.delete(doc_ids[-1])
|
||||
|
||||
# # Save state
|
||||
# rag.save("hqd_rag_state")
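
# # Restore the saved state into a fresh instance (a minimal sketch using the
# # load() method defined above)
# restored = HQDRAG()
# restored.load("hqd_rag_state")
# print("Restored documents:", len(restored.documents))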
|
||||
|
||||
|
||||
|
@@ -0,0 +1,554 @@
from typing import Dict, List
from datetime import datetime
from loguru import logger
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
import asyncio
import json
import aiohttp
from bs4 import BeautifulSoup
import xml.etree.ElementTree as ET
|
||||
|
||||
# Configure logging
|
||||
logger.add("forex_forest.log", rotation="500 MB", level="INFO")
|
||||
|
||||
|
||||
class ForexDataFeed:
|
||||
"""Real-time forex data collector using free open sources"""
|
||||
|
||||
def __init__(self):
|
||||
self.pairs = [
|
||||
"EUR/USD",
|
||||
"GBP/USD",
|
||||
"USD/JPY",
|
||||
"AUD/USD",
|
||||
"USD/CAD",
|
||||
]
|
||||
|
||||
async def fetch_ecb_rates(self) -> Dict:
|
||||
"""Fetch exchange rates from European Central Bank (no key required)"""
|
||||
try:
|
||||
url = "https://www.ecb.europa.eu/stats/eurofxref/eurofxref-daily.xml"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
xml_data = await response.text()
|
||||
|
||||
root = ET.fromstring(xml_data)
|
||||
rates = {}
|
||||
for cube in root.findall(".//*[@currency]"):
|
||||
currency = cube.get("currency")
|
||||
rate = float(cube.get("rate"))
|
||||
rates[currency] = rate
|
||||
|
||||
# Calculate cross rates. The ECB quotes every rate as units of currency per
# 1 EUR, so the price of BASE/QUOTE (QUOTE units per 1 BASE) is
# rates[quote] / rates[base].
rates["EUR"] = 1.0  # Base currency
cross_rates = {}
for pair in self.pairs:
    base, quote = pair.split("/")
    if base in rates and quote in rates:
        cross_rates[pair] = rates[quote] / rates[base]
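        # Worked example with illustrative figures (not live data): if the feed
        # lists USD = 1.08 and GBP = 0.86 per EUR, then GBP/USD = 1.08 / 0.86 ≈ 1.26
        # and EUR/USD = 1.08 / 1.0 = 1.08.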
|
||||
|
||||
return cross_rates
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching ECB rates: {e}")
|
||||
return {}
|
||||
|
||||
async def fetch_forex_factory_data(self) -> Dict:
|
||||
"""Scrape trading data from Forex Factory"""
|
||||
try:
|
||||
url = "https://www.forexfactory.com"
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers=headers
|
||||
) as response:
|
||||
text = await response.text()
|
||||
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
||||
# Get calendar events
|
||||
calendar = []
|
||||
calendar_table = soup.find(
|
||||
"table", class_="calendar__table"
|
||||
)
|
||||
if calendar_table:
|
||||
for row in calendar_table.find_all(
|
||||
"tr", class_="calendar__row"
|
||||
):
|
||||
try:
|
||||
event = {
|
||||
"currency": row.find(
|
||||
"td", class_="calendar__currency"
|
||||
).text.strip(),
|
||||
"event": row.find(
|
||||
"td", class_="calendar__event"
|
||||
).text.strip(),
|
||||
"impact": row.find(
|
||||
"td", class_="calendar__impact"
|
||||
).text.strip(),
|
||||
"time": row.find(
|
||||
"td", class_="calendar__time"
|
||||
).text.strip(),
|
||||
}
|
||||
calendar.append(event)
|
||||
except (AttributeError, TypeError):  # skip rows with missing cells
|
||||
continue
|
||||
|
||||
return {"calendar": calendar}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Forex Factory data: {e}")
|
||||
return {}
|
||||
|
||||
async def fetch_tradingeconomics_data(self) -> Dict:
|
||||
"""Scrape economic data from Trading Economics"""
|
||||
try:
|
||||
url = "https://tradingeconomics.com/calendar"
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers=headers
|
||||
) as response:
|
||||
text = await response.text()
|
||||
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
||||
# Get economic indicators
|
||||
indicators = []
|
||||
calendar_table = soup.find("table", class_="table")
|
||||
if calendar_table:
|
||||
for row in calendar_table.find_all("tr")[
|
||||
1:
|
||||
]: # Skip header
|
||||
try:
|
||||
cols = row.find_all("td")
|
||||
indicator = {
|
||||
"country": cols[0].text.strip(),
|
||||
"indicator": cols[1].text.strip(),
|
||||
"actual": cols[2].text.strip(),
|
||||
"previous": cols[3].text.strip(),
|
||||
"consensus": cols[4].text.strip(),
|
||||
}
|
||||
indicators.append(indicator)
|
||||
except (AttributeError, IndexError):  # skip malformed rows
|
||||
continue
|
||||
|
||||
return {"indicators": indicators}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching Trading Economics data: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def fetch_dailyfx_data(self) -> Dict:
|
||||
"""Scrape market analysis from DailyFX"""
|
||||
try:
|
||||
url = "https://www.dailyfx.com/market-news"
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers=headers
|
||||
) as response:
|
||||
text = await response.text()
|
||||
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
||||
# Get market news and analysis
|
||||
news = []
|
||||
articles = soup.find_all("article", class_="dfx-article")
|
||||
for article in articles[:10]: # Get latest 10 articles
|
||||
try:
|
||||
news_item = {
|
||||
"title": article.find("h3").text.strip(),
|
||||
"summary": article.find("p").text.strip(),
|
||||
"currency": article.get(
|
||||
"data-currency", "General"
|
||||
),
|
||||
"timestamp": article.find("time").get(
|
||||
"datetime"
|
||||
),
|
||||
}
|
||||
news.append(news_item)
|
||||
except (AttributeError, TypeError):  # skip articles with missing fields
|
||||
continue
|
||||
|
||||
return {"news": news}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching DailyFX data: {e}")
|
||||
return {}
|
||||
|
||||
async def fetch_all_data(self) -> Dict:
|
||||
"""Fetch and combine all forex data sources"""
|
||||
try:
|
||||
# Fetch data from all sources concurrently
|
||||
rates, ff_data, te_data, dx_data = await asyncio.gather(
|
||||
self.fetch_ecb_rates(),
|
||||
self.fetch_forex_factory_data(),
|
||||
self.fetch_tradingeconomics_data(),
|
||||
self.fetch_dailyfx_data(),
|
||||
)
|
||||
|
||||
# Combine all data
|
||||
market_data = {
|
||||
"exchange_rates": rates,
|
||||
"calendar": ff_data.get("calendar", []),
|
||||
"economic_indicators": te_data.get("indicators", []),
|
||||
"market_news": dx_data.get("news", []),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return market_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching all data: {e}")
|
||||
return {}
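
# A minimal way to exercise the combined feed on its own (requires network access
# to the public sources above; illustrative, not part of the original flow):
#
#     market_data = asyncio.run(ForexDataFeed().fetch_all_data())
#     print(json.dumps(market_data, indent=2))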
|
||||
|
||||
|
||||
# (The specialized agent prompts and the ForexForestSystem orchestrator follow.)
|
||||
|
||||
# Specialized Agent Prompts
|
||||
TECHNICAL_ANALYST_PROMPT = """You are an expert forex technical analyst agent.
|
||||
Your responsibilities:
|
||||
1. Analyze real-time exchange rate data for patterns and trends
|
||||
2. Calculate cross-rates and currency correlations
|
||||
3. Generate trading signals based on price action
|
||||
4. Monitor market volatility and momentum
|
||||
5. Identify key support and resistance levels
|
||||
|
||||
Data Format:
|
||||
- You will receive exchange rates from ECB and calculated cross-rates
|
||||
- Focus on major currency pairs and their relationships
|
||||
- Consider market volatility and trading volumes
|
||||
|
||||
Output Format:
|
||||
{
|
||||
"analysis_type": "technical",
|
||||
"timestamp": "ISO timestamp",
|
||||
"signals": [
|
||||
{
|
||||
"pair": "Currency pair",
|
||||
"trend": "bullish/bearish/neutral",
|
||||
"strength": 1-10,
|
||||
"key_levels": {"support": [], "resistance": []},
|
||||
"recommendation": "buy/sell/hold"
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
FUNDAMENTAL_ANALYST_PROMPT = """You are an expert forex fundamental analyst agent.
|
||||
Your responsibilities:
|
||||
1. Analyze economic calendar events and their impact
|
||||
2. Evaluate economic indicators from Trading Economics
|
||||
3. Assess market news and sentiment from DailyFX
|
||||
4. Monitor central bank actions and policies
|
||||
5. Track geopolitical events affecting currencies
|
||||
|
||||
Data Format:
|
||||
- Economic calendar events with impact levels
|
||||
- Latest economic indicators and previous values
|
||||
- Market news and analysis from reliable sources
|
||||
- Central bank statements and policy changes
|
||||
|
||||
Output Format:
|
||||
{
|
||||
"analysis_type": "fundamental",
|
||||
"timestamp": "ISO timestamp",
|
||||
"assessments": [
|
||||
{
|
||||
"currency": "Currency code",
|
||||
"economic_outlook": "positive/negative/neutral",
|
||||
"key_events": [],
|
||||
"impact_score": 1-10,
|
||||
"bias": "bullish/bearish/neutral"
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
MARKET_SENTIMENT_PROMPT = """You are an expert market sentiment analysis agent.
|
||||
Your responsibilities:
|
||||
1. Analyze news sentiment from DailyFX articles
|
||||
2. Track market positioning and bias
|
||||
3. Monitor risk sentiment and market fear/greed
|
||||
4. Identify potential market drivers
|
||||
5. Detect sentiment shifts and extremes
|
||||
|
||||
Data Format:
|
||||
- Market news and analysis articles
|
||||
- Trading sentiment indicators
|
||||
- Risk event calendar
|
||||
- Market commentary and analysis
|
||||
|
||||
Output Format:
|
||||
{
|
||||
"analysis_type": "sentiment",
|
||||
"timestamp": "ISO timestamp",
|
||||
"sentiment_data": [
|
||||
{
|
||||
"pair": "Currency pair",
|
||||
"sentiment": "risk-on/risk-off",
|
||||
"strength": 1-10,
|
||||
"key_drivers": [],
|
||||
"outlook": "positive/negative/neutral"
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
STRATEGY_COORDINATOR_PROMPT = """You are the lead forex strategy coordination agent.
|
||||
Your responsibilities:
|
||||
1. Synthesize technical, fundamental, and sentiment analysis
|
||||
2. Generate final trading recommendations
|
||||
3. Manage risk exposure and position sizing
|
||||
4. Coordinate entry and exit points
|
||||
5. Monitor open positions and adjust strategies
|
||||
|
||||
Data Format:
|
||||
- Analysis from technical, fundamental, and sentiment agents
|
||||
- Current market rates and conditions
|
||||
- Economic calendar and news events
|
||||
- Risk parameters and exposure limits
|
||||
|
||||
Output Format:
|
||||
{
|
||||
"analysis_type": "strategy",
|
||||
"timestamp": "ISO timestamp",
|
||||
"recommendations": [
|
||||
{
|
||||
"pair": "Currency pair",
|
||||
"action": "buy/sell/hold",
|
||||
"confidence": 1-10,
|
||||
"entry_points": [],
|
||||
"stop_loss": float,
|
||||
"take_profit": float,
|
||||
"rationale": "string"
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
|
||||
class ForexForestSystem:
|
||||
"""Main system coordinating the forest swarm and data feeds"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the forex forest system"""
|
||||
self.data_feed = ForexDataFeed()
|
||||
|
||||
# Create Technical Analysis Tree
|
||||
technical_agents = [
|
||||
TreeAgent(
|
||||
system_prompt=TECHNICAL_ANALYST_PROMPT,
|
||||
agent_name="Price Action Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=TECHNICAL_ANALYST_PROMPT,
|
||||
agent_name="Cross Rate Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=TECHNICAL_ANALYST_PROMPT,
|
||||
agent_name="Volatility Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
]
|
||||
|
||||
# Create Fundamental Analysis Tree
|
||||
fundamental_agents = [
|
||||
TreeAgent(
|
||||
system_prompt=FUNDAMENTAL_ANALYST_PROMPT,
|
||||
agent_name="Economic Data Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=FUNDAMENTAL_ANALYST_PROMPT,
|
||||
agent_name="News Impact Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=FUNDAMENTAL_ANALYST_PROMPT,
|
||||
agent_name="Central Bank Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
]
|
||||
|
||||
# Create Sentiment Analysis Tree
|
||||
sentiment_agents = [
|
||||
TreeAgent(
|
||||
system_prompt=MARKET_SENTIMENT_PROMPT,
|
||||
agent_name="News Sentiment Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=MARKET_SENTIMENT_PROMPT,
|
||||
agent_name="Risk Sentiment Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=MARKET_SENTIMENT_PROMPT,
|
||||
agent_name="Market Positioning Analyst",
|
||||
model_name="gpt-4o",
|
||||
),
|
||||
]
|
||||
|
||||
# Create Strategy Coordination Tree
|
||||
strategy_agents = [
|
||||
TreeAgent(
|
||||
system_prompt=STRATEGY_COORDINATOR_PROMPT,
|
||||
agent_name="Lead Strategy Coordinator",
|
||||
model_name="gpt-4",
|
||||
temperature=0.5,
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=STRATEGY_COORDINATOR_PROMPT,
|
||||
agent_name="Risk Manager",
|
||||
model_name="gpt-4",
|
||||
temperature=0.5,
|
||||
),
|
||||
TreeAgent(
|
||||
system_prompt=STRATEGY_COORDINATOR_PROMPT,
|
||||
agent_name="Position Manager",
|
||||
model_name="gpt-4",
|
||||
temperature=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
# Create trees
|
||||
self.technical_tree = Tree(
|
||||
tree_name="Technical Analysis", agents=technical_agents
|
||||
)
|
||||
self.fundamental_tree = Tree(
|
||||
tree_name="Fundamental Analysis",
|
||||
agents=fundamental_agents,
|
||||
)
|
||||
self.sentiment_tree = Tree(
|
||||
tree_name="Sentiment Analysis", agents=sentiment_agents
|
||||
)
|
||||
self.strategy_tree = Tree(
|
||||
tree_name="Strategy Coordination", agents=strategy_agents
|
||||
)
|
||||
|
||||
# Create forest swarm
|
||||
self.forest = ForestSwarm(
|
||||
trees=[
|
||||
self.technical_tree,
|
||||
self.fundamental_tree,
|
||||
self.sentiment_tree,
|
||||
self.strategy_tree,
|
||||
]
|
||||
)
|
||||
|
||||
logger.info("Forex Forest System initialized successfully")
|
||||
|
||||
async def prepare_analysis_task(self) -> str:
|
||||
"""Prepare the analysis task with real-time data"""
|
||||
try:
|
||||
market_data = await self.data_feed.fetch_all_data()
|
||||
|
||||
task = {
|
||||
"action": "analyze_forex_markets",
|
||||
"market_data": market_data,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"analysis_required": [
|
||||
"technical",
|
||||
"fundamental",
|
||||
"sentiment",
|
||||
"strategy",
|
||||
],
|
||||
}
|
||||
|
||||
return json.dumps(task, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing analysis task: {e}")
|
||||
raise
|
||||
|
||||
async def run_analysis_cycle(self) -> Dict:
|
||||
"""Run a complete analysis cycle with the forest swarm"""
|
||||
try:
|
||||
# Prepare task with real-time data
|
||||
task = await self.prepare_analysis_task()
|
||||
|
||||
# Run forest swarm analysis
|
||||
result = self.forest.run(task)
|
||||
|
||||
# Parse and validate results
|
||||
analysis = (
|
||||
json.loads(result)
|
||||
if isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
logger.info("Analysis cycle completed successfully")
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in analysis cycle: {e}")
|
||||
raise
|
||||
|
||||
async def monitor_markets(self, interval_seconds: int = 300):
|
||||
"""Continuously monitor markets and run analysis"""
|
||||
while True:
|
||||
try:
|
||||
# Run analysis cycle
|
||||
analysis = await self.run_analysis_cycle()
|
||||
|
||||
# Log results
|
||||
logger.info("Market analysis completed")
|
||||
logger.debug(
|
||||
f"Analysis results: {json.dumps(analysis, indent=2)}"
|
||||
)
|
||||
|
||||
# Process any trading signals
|
||||
if "recommendations" in analysis:
|
||||
await self.process_trading_signals(
|
||||
analysis["recommendations"]
|
||||
)
|
||||
|
||||
# Wait for next interval
|
||||
await asyncio.sleep(interval_seconds)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in market monitoring: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def process_trading_signals(
|
||||
self, recommendations: List[Dict]
|
||||
):
|
||||
"""Process and log trading signals from analysis"""
|
||||
try:
|
||||
for rec in recommendations:
|
||||
logger.info(
|
||||
f"Trading Signal: {rec['pair']} - {rec['action']}"
|
||||
)
|
||||
logger.info(f"Confidence: {rec['confidence']}/10")
|
||||
logger.info(f"Entry Points: {rec['entry_points']}")
|
||||
logger.info(f"Stop Loss: {rec['stop_loss']}")
|
||||
logger.info(f"Take Profit: {rec['take_profit']}")
|
||||
logger.info(f"Rationale: {rec['rationale']}")
|
||||
logger.info("-" * 50)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trading signals: {e}")
|
||||
|
||||
|
||||
# Example usage
|
||||
async def main():
|
||||
"""Main function to run the Forex Forest System"""
|
||||
try:
|
||||
system = ForexForestSystem()
|
||||
await system.monitor_markets()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set up asyncio event loop and run the system
|
||||
asyncio.run(main())
|
@@ -1,618 +0,0 @@
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from loguru import logger

from dataclasses import dataclass
from typing import Optional, Tuple, Dict
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig:
|
||||
"""Configuration class for MoE Transformer model parameters."""
|
||||
|
||||
vocab_size: int = 50257
|
||||
hidden_size: int = 768
|
||||
num_attention_heads: int = 12
|
||||
num_expert_layers: int = 4
|
||||
num_experts: int = 8
|
||||
expert_capacity: int = 32
|
||||
max_position_embeddings: int = 1024
|
||||
dropout_prob: float = 0.1
|
||||
layer_norm_epsilon: float = 1e-5
|
||||
initializer_range: float = 0.02
|
||||
num_query_groups: int = 4 # For multi-query attention
|
||||
|
||||
|
||||
class ExpertLayer(nn.Module):
|
||||
"""Individual expert neural network."""
|
||||
|
||||
def __init__(self, config: TransformerConfig):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(
|
||||
config.hidden_size, 4 * config.hidden_size
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
4 * config.hidden_size, config.hidden_size
|
||||
)
|
||||
self.activation = nn.GELU()
|
||||
self.dropout = nn.Dropout(config.dropout_prob)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MixtureOfExperts(nn.Module):
|
||||
"""Mixture of Experts layer with dynamic routing."""
|
||||
|
||||
def __init__(self, config: TransformerConfig):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.expert_capacity = config.expert_capacity
|
||||
|
||||
# Create expert networks
|
||||
self.experts = nn.ModuleList(
|
||||
[ExpertLayer(config) for _ in range(config.num_experts)]
|
||||
)
|
||||
|
||||
# Router network
|
||||
self.router = nn.Linear(
|
||||
config.hidden_size, config.num_experts
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
|
||||
"""Route inputs to experts and combine outputs."""
|
||||
batch_size, seq_len, hidden_size = x.shape
|
||||
|
||||
# Calculate routing probabilities
|
||||
router_logits = self.router(x)
|
||||
routing_weights = F.softmax(router_logits, dim=-1)
|
||||
|
||||
# Select top-k experts
|
||||
top_k = 2
|
||||
gates, indices = torch.topk(routing_weights, top_k, dim=-1)
|
||||
gates = F.softmax(gates, dim=-1)
|
||||
|
||||
# Process inputs through selected experts
|
||||
final_output = torch.zeros_like(x)
|
||||
router_load = torch.zeros(self.num_experts, device=x.device)
|
||||
|
||||
for i in range(top_k):
|
||||
expert_index = indices[..., i]
|
||||
gate = gates[..., i : i + 1]
|
||||
|
||||
# Count expert assignments
|
||||
for j in range(self.num_experts):
|
||||
router_load[j] += (expert_index == j).float().sum()
|
||||
|
||||
# Process through selected experts
|
||||
for j in range(self.num_experts):
|
||||
mask = expert_index == j
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
expert_input = x[mask]
|
||||
expert_output = self.experts[j](expert_input)
|
||||
final_output[mask] += gate[mask] * expert_output
|
||||
|
||||
aux_loss = router_load.float().var() / (
|
||||
router_load.float().mean() ** 2
|
||||
)
|
||||
|
||||
return final_output, {"load_balancing_loss": aux_loss}
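
# A small, self-contained sanity check for the routing layer above (illustrative
# shapes only, not part of the training pipeline): every token should come back
# with the same hidden size, plus a scalar load-balancing loss.
def _moe_smoke_test() -> None:
    cfg = TransformerConfig(hidden_size=64, num_attention_heads=4, num_experts=4, num_query_groups=2)
    moe = MixtureOfExperts(cfg)
    x = torch.randn(2, 8, cfg.hidden_size)  # (batch, seq_len, hidden)
    out, aux = moe(x)
    assert out.shape == x.shape
    assert aux["load_balancing_loss"].ndim == 0  # scalar auxiliary loss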
|
||||
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
"""Multi-Query Attention mechanism with proper multi-query group handling."""
|
||||
|
||||
def __init__(self, config: TransformerConfig):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.num_query_groups = config.num_query_groups
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = (
|
||||
config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
# Query projection maintains full head dimension
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.hidden_size
|
||||
)
|
||||
|
||||
# Key and value projections use reduced number of heads (query groups)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size,
|
||||
self.head_dim * config.num_query_groups,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size,
|
||||
self.head_dim * config.num_query_groups,
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(config.dropout_prob)
|
||||
|
||||
# Calculate heads per group for proper reshaping
|
||||
self.heads_per_group = (
|
||||
self.num_attention_heads // self.num_query_groups
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
cache: Optional[Dict[str, Tensor]] = None,
|
||||
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
# Project queries, keys, and values
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
# Reshape queries to full number of heads
|
||||
queries = queries.view(
|
||||
batch_size,
|
||||
seq_length,
|
||||
self.num_attention_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
# Reshape keys and values to number of query groups
|
||||
keys = keys.view(
|
||||
batch_size,
|
||||
seq_length,
|
||||
self.num_query_groups,
|
||||
self.head_dim,
|
||||
)
|
||||
values = values.view(
|
||||
batch_size,
|
||||
seq_length,
|
||||
self.num_query_groups,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
# Transpose for batch matrix multiplication
|
||||
queries = queries.transpose(
|
||||
1, 2
|
||||
) # (batch, n_heads, seq_len, head_dim)
|
||||
keys = keys.transpose(
|
||||
1, 2
|
||||
) # (batch, n_groups, seq_len, head_dim)
|
||||
values = values.transpose(
|
||||
1, 2
|
||||
) # (batch, n_groups, seq_len, head_dim)
|
||||
|
||||
# Repeat keys and values for each head in the group
|
||||
keys = keys.repeat_interleave(self.heads_per_group, dim=1)
|
||||
values = values.repeat_interleave(self.heads_per_group, dim=1)
|
||||
|
||||
# Compute attention scores
|
||||
scale = 1.0 / math.sqrt(self.head_dim)
|
||||
scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale
|
||||
|
||||
if attention_mask is not None:
|
||||
# Expand attention mask to match scores dimensions
|
||||
expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
expanded_mask = expanded_mask.expand(
|
||||
batch_size,
|
||||
self.num_attention_heads,
|
||||
seq_length,
|
||||
seq_length,
|
||||
)
|
||||
mask_value = torch.finfo(scores.dtype).min
|
||||
attention_mask = expanded_mask.eq(0).float() * mask_value
|
||||
scores = scores + attention_mask
|
||||
|
||||
attention_weights = F.softmax(scores, dim=-1)
|
||||
attention_weights = self.dropout(attention_weights)
|
||||
|
||||
# Compute attention output
|
||||
attention_output = torch.matmul(attention_weights, values)
|
||||
attention_output = attention_output.transpose(1, 2)
|
||||
attention_output = attention_output.reshape(
|
||||
batch_size, seq_length, -1
|
||||
)
|
||||
|
||||
return attention_output, None
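
# Note on the grouping above: with the defaults (num_attention_heads = 12,
# num_query_groups = 4) each key/value projection is shared by 12 / 4 = 3 query
# heads, so the key/value tensors are roughly a third of the size required by
# standard multi-head attention.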
|
||||
|
||||
|
||||
class MoETransformer(nn.Module):
|
||||
"""
|
||||
Production-grade Transformer model with Mixture of Experts and Multi-Query Attention.
|
||||
|
||||
Features:
|
||||
- Multi-Query Attention mechanism for efficient inference
|
||||
- Mixture of Experts for dynamic routing and specialization
|
||||
- Real-time weight updates based on input similarity
|
||||
- Built-in logging and monitoring
|
||||
- Type annotations for better code maintainability
|
||||
"""
|
||||
|
||||
def __init__(self, config: TransformerConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Initialize components
|
||||
self.embedding = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, config.hidden_size
|
||||
)
|
||||
|
||||
# Multi-Query Attention layers
|
||||
self.attention_layers = nn.ModuleList(
|
||||
[
|
||||
MultiQueryAttention(config)
|
||||
for _ in range(config.num_expert_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Mixture of Experts layers
|
||||
self.moe_layers = nn.ModuleList(
|
||||
[
|
||||
MixtureOfExperts(config)
|
||||
for _ in range(config.num_expert_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Layer normalization and dropout
|
||||
self.layer_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.dropout = nn.Dropout(config.dropout_prob)
|
||||
|
||||
# Output projection
|
||||
self.output_projection = nn.Linear(
|
||||
config.hidden_size, config.vocab_size
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
logger.info("Initialized MoETransformer model")
|
||||
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize model weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(
|
||||
mean=0.0, std=self.config.initializer_range
|
||||
)
|
||||
if (
|
||||
isinstance(module, nn.Linear)
|
||||
and module.bias is not None
|
||||
):
|
||||
module.bias.data.zero_()
|
||||
|
||||
def get_position_embeddings(self, position_ids: Tensor) -> Tensor:
|
||||
"""Generate position embeddings."""
|
||||
return self.position_embedding(position_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
cache: Optional[Dict[str, Tensor]] = None,
|
||||
) -> Tuple[Tensor, Dict]:
|
||||
"""
|
||||
Forward pass through the model.
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
attention_mask: Attention mask for padding
|
||||
position_ids: Position IDs for positioning encoding
|
||||
cache: Cache for key/value states in generation
|
||||
|
||||
Returns:
|
||||
tuple: (logits, auxiliary_outputs)
|
||||
"""
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
seq_length, dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(
|
||||
input_ids
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
position_embeds = self.get_position_embeddings(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
# Initialize auxiliary outputs
|
||||
aux_outputs = {"moe_losses": []}
|
||||
|
||||
# Process through transformer layers
|
||||
for attention_layer, moe_layer in zip(
|
||||
self.attention_layers, self.moe_layers
|
||||
):
|
||||
# Multi-Query Attention
|
||||
attention_output, _ = attention_layer(
|
||||
hidden_states, attention_mask, cache
|
||||
)
|
||||
hidden_states = self.layer_norm(
|
||||
hidden_states + attention_output
|
||||
)
|
||||
|
||||
# Mixture of Experts
|
||||
moe_output, moe_aux = moe_layer(hidden_states)
|
||||
hidden_states = self.layer_norm(
|
||||
hidden_states + moe_output
|
||||
)
|
||||
aux_outputs["moe_losses"].append(
|
||||
moe_aux["load_balancing_loss"]
|
||||
)
|
||||
|
||||
# Final output projection
|
||||
logits = self.output_projection(hidden_states)
|
||||
|
||||
return logits, aux_outputs
|
||||
|
||||
def fetch_loss(
|
||||
self,
|
||||
logits: Tensor,
|
||||
labels: Tensor,
|
||||
aux_outputs: Dict,
|
||||
reduction: str = "mean",
|
||||
) -> Tensor:
|
||||
"""
|
||||
Calculate the total loss including MoE balancing losses.
|
||||
|
||||
Args:
|
||||
logits: Model output logits
|
||||
labels: Ground truth labels
|
||||
aux_outputs: Auxiliary outputs from forward pass
|
||||
reduction: Loss reduction method
|
||||
|
||||
Returns:
|
||||
Tensor: Total loss
|
||||
"""
|
||||
# Calculate cross entropy loss
|
||||
ce_loss = F.cross_entropy(
|
||||
logits.view(-1, self.config.vocab_size),
|
||||
labels.view(-1),
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
# Calculate MoE loss
|
||||
moe_loss = torch.stack(aux_outputs["moe_losses"]).mean()
|
||||
|
||||
# Combine losses
|
||||
total_loss = ce_loss + 0.01 * moe_loss
|
||||
|
||||
logger.debug(
|
||||
f"CE Loss: {ce_loss.item():.4f}, "
|
||||
f"MoE Loss: {moe_loss.item():.4f}"
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
max_length: int = 100,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 50,
|
||||
top_p: float = 0.9,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Generate text using the model.
|
||||
|
||||
Args:
|
||||
input_ids: Initial input tokens
|
||||
max_length: Maximum sequence length to generate
|
||||
temperature: Sampling temperature
|
||||
top_k: Number of highest probability tokens to keep
|
||||
top_p: Cumulative probability for nucleus sampling
|
||||
|
||||
Returns:
|
||||
Tensor: Generated token IDs
|
||||
"""
|
||||
batch_size = input_ids.shape[0]
|
||||
device = input_ids.device
|
||||
|
||||
# Initialize sequence with input_ids
|
||||
generated = input_ids
|
||||
|
||||
# Cache for key-value pairs
|
||||
cache = {}
|
||||
|
||||
for _ in range(max_length):
|
||||
# Get position IDs for current sequence
|
||||
position_ids = torch.arange(
|
||||
generated.shape[1], dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).expand(
|
||||
batch_size, -1
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
logits, _ = self.forward(
|
||||
generated, position_ids=position_ids, cache=cache
|
||||
)
|
||||
|
||||
# Get next token logits
|
||||
next_token_logits = logits[:, -1, :] / temperature
|
||||
|
||||
# Apply top-k filtering
|
||||
if top_k > 0:
|
||||
indices_to_remove = (
|
||||
next_token_logits
|
||||
< torch.topk(next_token_logits, top_k)[0][
|
||||
..., -1, None
|
||||
]
|
||||
)
|
||||
next_token_logits[indices_to_remove] = float("-inf")
|
||||
|
||||
# Apply top-p (nucleus) filtering
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(
|
||||
next_token_logits, descending=True
|
||||
)
|
||||
cumulative_probs = torch.cumsum(
|
||||
F.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = (
|
||||
sorted_indices_to_remove[..., :-1].clone()
|
||||
)
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices[
|
||||
sorted_indices_to_remove
|
||||
]
|
||||
next_token_logits[indices_to_remove] = float("-inf")
|
||||
|
||||
# Sample next token
|
||||
probs = F.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
# Append next token to sequence
|
||||
generated = torch.cat((generated, next_token), dim=1)
|
||||
|
||||
# Check for end of sequence token
|
||||
if (next_token == self.config.vocab_size - 1).all():
|
||||
break
|
||||
|
||||
return generated
|
||||
|
||||
|
||||
# Initialize model configuration
|
||||
config = TransformerConfig(
|
||||
vocab_size=50257,
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
num_expert_layers=4,
|
||||
num_experts=8,
|
||||
expert_capacity=32,
|
||||
max_position_embeddings=1024,
|
||||
num_query_groups=4,
|
||||
)
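
# Rough scale of this configuration (back-of-envelope, weights only): the token
# embedding alone is 50257 * 768 ≈ 38.6M parameters, each expert FFN is about
# 2 * (768 * 3072) ≈ 4.7M, so 4 MoE layers * 8 experts contribute ≈ 151M.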
|
||||
|
||||
|
||||
def prepare_sample_data(
|
||||
batch_size: int = 8,
|
||||
seq_length: int = 512,
|
||||
vocab_size: int = 50257,
|
||||
) -> DataLoader:
|
||||
"""Create sample data for demonstration."""
|
||||
# Create random input sequences
|
||||
input_ids = torch.randint(
|
||||
0, vocab_size, (100, seq_length) # 100 samples
|
||||
)
|
||||
|
||||
# Create random target sequences (illustrative only; real targets would be the inputs shifted by one token)
|
||||
labels = torch.randint(0, vocab_size, (100, seq_length))
|
||||
|
||||
# Create attention masks (1 for real tokens, 0 for padding)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# Create dataset and dataloader
|
||||
dataset = TensorDataset(input_ids, attention_mask, labels)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=True
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
def train_step(
|
||||
model: MoETransformer,
|
||||
batch: tuple,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
) -> float:
|
||||
"""Execute single training step."""
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Unpack batch
|
||||
input_ids, attention_mask, labels = [b.to(device) for b in batch]
|
||||
|
||||
# Forward pass
|
||||
logits, aux_outputs = model(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# Calculate loss
|
||||
loss = model.fetch_loss(logits, labels, aux_outputs)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
return loss.item()
|
||||
|
||||
|
||||
def main():
|
||||
# Set device
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Initialize model
|
||||
model = MoETransformer(config).to(device)
|
||||
logger.info("Model initialized")
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(), lr=1e-4, weight_decay=0.01
|
||||
)
|
||||
|
||||
# Prepare data
|
||||
dataloader = prepare_sample_data()
|
||||
logger.info("Data prepared")
|
||||
|
||||
# Training loop
|
||||
num_epochs = 3
|
||||
for epoch in range(num_epochs):
|
||||
epoch_losses = []
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
loss = train_step(model, batch, optimizer, device)
|
||||
epoch_losses.append(loss)
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{num_epochs} "
|
||||
f"Batch {batch_idx}/{len(dataloader)} "
|
||||
f"Loss: {loss:.4f}"
|
||||
)
|
||||
|
||||
avg_loss = np.mean(epoch_losses)
|
||||
logger.info(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")
|
||||
|
||||
# Generation example
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# Prepare input prompt
|
||||
prompt = torch.randint(0, config.vocab_size, (1, 10)).to(
|
||||
device
|
||||
)
|
||||
|
||||
# Generate sequence
|
||||
generated = model.generate(
|
||||
input_ids=prompt,
|
||||
max_length=50,
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
)
|
||||
|
||||
logger.info(f"Generated sequence shape: {generated.shape}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,333 @@
import os
from typing import List, Optional
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic.v1 import validator
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_exponential

from swarm_models import OpenAIFunctionCaller, OpenAIChat
from swarms.structs.agent import Agent
from swarms.structs.swarm_router import SwarmRouter
from swarms.structs.agents_available import showcase_available_agents
|
||||
|
||||
|
||||
BOSS_SYSTEM_PROMPT = """
|
||||
Manage a swarm of worker agents to efficiently serve the user by deciding whether to create new agents or delegate tasks. Ensure operations are efficient and effective.
|
||||
|
||||
### Instructions:
|
||||
|
||||
1. **Task Assignment**:
|
||||
- Analyze available worker agents when a task is presented.
|
||||
- Delegate tasks to existing agents with clear, direct, and actionable instructions if an appropriate agent is available.
|
||||
- If no suitable agent exists, create a new agent with a fitting system prompt to handle the task.
|
||||
|
||||
2. **Agent Creation**:
|
||||
- Name agents according to the task they are intended to perform (e.g., "Twitter Marketing Agent").
|
||||
- Provide each new agent with a concise and clear system prompt that includes its role, objectives, and any tools it can utilize.
|
||||
|
||||
3. **Efficiency**:
|
||||
- Minimize redundancy and maximize task completion speed.
|
||||
- Avoid unnecessary agent creation if an existing agent can fulfill the task.
|
||||
|
||||
4. **Communication**:
|
||||
- Be explicit in task delegation instructions to avoid ambiguity and ensure effective task execution.
|
||||
- Require agents to report back on task completion or encountered issues.
|
||||
|
||||
5. **Reasoning and Decisions**:
|
||||
- Offer brief reasoning when selecting or creating agents to maintain transparency.
|
||||
- Avoid using an agent if unnecessary, with a clear explanation if no agents are suitable for a task.
|
||||
|
||||
# Output Format
|
||||
|
||||
Present your plan in clear, bullet-point format or short concise paragraphs, outlining task assignment, agent creation, efficiency strategies, and communication protocols.
|
||||
|
||||
# Notes
|
||||
|
||||
- Preserve transparency by always providing reasoning for task-agent assignments and creation.
|
||||
- Ensure instructions to agents are unambiguous to minimize error.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
"""Configuration for an individual agent in a swarm"""
|
||||
|
||||
name: str = Field(
|
||||
description="The name of the agent", example="Research-Agent"
|
||||
)
|
||||
description: str = Field(
|
||||
description="A description of the agent's purpose and capabilities",
|
||||
example="Agent responsible for researching and gathering information",
|
||||
)
|
||||
system_prompt: str = Field(
|
||||
description="The system prompt that defines the agent's behavior",
|
||||
example="You are a research agent. Your role is to gather and analyze information...",
|
||||
)
|
||||
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
if not v.strip():
|
||||
raise ValueError("Agent name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@validator("system_prompt")
|
||||
def validate_system_prompt(cls, v):
|
||||
if not v.strip():
|
||||
raise ValueError("System prompt cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class SwarmConfig(BaseModel):
|
||||
"""Configuration for a swarm of cooperative agents"""
|
||||
|
||||
name: str = Field(
|
||||
description="The name of the swarm",
|
||||
example="Research-Writing-Swarm",
|
||||
)
|
||||
description: str = Field(
|
||||
description="The description of the swarm's purpose and capabilities",
|
||||
example="A swarm of agents that work together to research topics and write articles",
|
||||
)
|
||||
agents: List[AgentConfig] = Field(
|
||||
description="The list of agents that make up the swarm",
|
||||
min_items=1,
|
||||
)
|
||||
|
||||
@validator("agents")
|
||||
def validate_agents(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Swarm must have at least one agent")
|
||||
return v
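
# A minimal hand-built configuration matching the schema above (values taken from
# the Field examples; illustrative only), handy for exercising the validators
# without calling the model:
#
#     SwarmConfig(
#         name="Research-Writing-Swarm",
#         description="A swarm of agents that work together to research topics and write articles",
#         agents=[
#             AgentConfig(
#                 name="Research-Agent",
#                 description="Agent responsible for researching and gathering information",
#                 system_prompt="You are a research agent. Your role is to gather and analyze information...",
#             )
#         ],
#     )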
|
||||
|
||||
|
||||
class AutoSwarmBuilder:
|
||||
"""A class that automatically builds and manages swarms of AI agents with enhanced error handling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-4",
|
||||
):
|
||||
self.name = name or "DefaultSwarm"
|
||||
self.description = description or "Generic AI Agent Swarm"
|
||||
self.verbose = verbose
|
||||
self.agents_pool = []
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.model_name = model_name
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key must be provided either through initialization or environment variable"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initialized AutoSwarmBuilder",
|
||||
extra={
|
||||
"swarm_name": self.name,
|
||||
"description": self.description,
|
||||
"model": self.model_name,
|
||||
},
|
||||
)
|
||||
|
||||
# Initialize OpenAI chat model
|
||||
try:
|
||||
self.chat_model = OpenAIChat(
|
||||
openai_api_key=self.api_key,
|
||||
model_name=self.model_name,
|
||||
temperature=0.1,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to initialize OpenAI chat model: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
)
|
||||
def run(self, task: str, image_url: Optional[str] = None) -> str:
|
||||
"""Run the swarm on a given task with error handling and retries."""
|
||||
if not task or not task.strip():
|
||||
raise ValueError("Task cannot be empty")
|
||||
|
||||
logger.info("Starting swarm execution", extra={"task": task})
|
||||
|
||||
try:
|
||||
# Create agents for the task
|
||||
agents = self._create_agents(task, image_url)
|
||||
if not agents:
|
||||
raise ValueError(
|
||||
"No agents were created for the task"
|
||||
)
|
||||
|
||||
# Execute the task through the swarm router
|
||||
logger.info(
|
||||
"Routing task through swarm",
|
||||
extra={"num_agents": len(agents)},
|
||||
)
|
||||
output = self.swarm_router(agents, task, image_url)
|
||||
|
||||
logger.info("Swarm execution completed successfully")
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during swarm execution: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
def _create_agents(
|
||||
self, task: str, image_url: Optional[str] = None
|
||||
) -> List[Agent]:
|
||||
"""Create the necessary agents for a task with enhanced error handling."""
|
||||
logger.info("Creating agents for task", extra={"task": task})
|
||||
|
||||
try:
|
||||
model = OpenAIFunctionCaller(
|
||||
system_prompt=BOSS_SYSTEM_PROMPT,
|
||||
api_key=self.api_key,
|
||||
temperature=0.1,
|
||||
base_model=SwarmConfig,
|
||||
)
|
||||
|
||||
agents_config = model.run(task)
|
||||
print(f"{agents_config}")
|
||||
|
||||
if isinstance(agents_config, dict):
|
||||
agents_config = SwarmConfig(**agents_config)
|
||||
|
||||
# Update swarm configuration
|
||||
self.name = agents_config.name
|
||||
self.description = agents_config.description
|
||||
|
||||
# Create agents from configuration
|
||||
agents = []
|
||||
for agent_config in agents_config.agents:
|
||||
if isinstance(agent_config, dict):
|
||||
agent_config = AgentConfig(**agent_config)
|
||||
|
||||
agent = self.build_agent(
|
||||
agent_name=agent_config.name,
|
||||
agent_description=agent_config.description,
|
||||
agent_system_prompt=agent_config.system_prompt,
|
||||
)
|
||||
agents.append(agent)
|
||||
|
||||
# Add available agents showcase to system prompts
|
||||
agents_available = showcase_available_agents(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
agents=agents,
|
||||
)
|
||||
|
||||
for agent in agents:
|
||||
agent.system_prompt += "\n" + agents_available
|
||||
|
||||
logger.info(
|
||||
"Successfully created agents",
|
||||
extra={"num_agents": len(agents)},
|
||||
)
|
||||
return agents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error creating agents: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
    def build_agent(
        self,
        agent_name: str,
        agent_description: str,
        agent_system_prompt: str,
    ) -> Agent:
        """Build a single agent with enhanced error handling."""
        logger.info(
            "Building agent", extra={"agent_name": agent_name}
        )

        try:
            agent = Agent(
                agent_name=agent_name,
                description=agent_description,
                system_prompt=agent_system_prompt,
                llm=self.chat_model,
                autosave=True,
                dashboard=False,
                verbose=self.verbose,
                dynamic_temperature_enabled=True,
                saved_state_path=f"states/{agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                user_name="swarms_corp",
                retry_attempts=3,
                context_length=200000,
                return_step_meta=False,
                output_type="str",
                streaming_on=False,
                auto_generate_prompt=True,
            )
            return agent

        except Exception as e:
            logger.error(
                f"Error building agent: {str(e)}", exc_info=True
            )
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def swarm_router(
        self,
        agents: List[Agent],
        task: str,
        image_url: Optional[str] = None,
    ) -> str:
        """Route tasks between agents in the swarm with error handling and retries."""
        logger.info(
            "Initializing swarm router",
            extra={"num_agents": len(agents)},
        )

        try:
            swarm_router_instance = SwarmRouter(
                name=self.name,
                description=self.description,
                agents=agents,
                swarm_type="auto",
            )

            formatted_task = f"{self.name} {self.description} {task}"
            result = swarm_router_instance.run(formatted_task)

            logger.info("Successfully completed swarm routing")
            return result

        except Exception as e:
            logger.error(
                f"Error in swarm router: {str(e)}", exc_info=True
            )
            raise


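# Note on the retry policy above: the stop_after_attempt / wait_exponential
# signature matches the tenacity library, so presumably a failing swarm_router
# call is re-attempted up to three times, with exponentially backed-off waits
# clamped between 4 and 10 seconds, before the final exception propagates.
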
swarm = AutoSwarmBuilder(
    name="ChipDesign-Swarm",
    description="A swarm of specialized AI agents for chip design",
    api_key="your-api-key",  # Optional if set in environment
    model_name="gpt-4",  # Optional, defaults to gpt-4
)

result = swarm.run(
    "Design a new AI accelerator chip optimized for transformer model inference..."
)
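# The api_key argument above is optional when the key is already available in
# the environment. A minimal sketch of that setup, assuming the key lives in an
# OPENAI_API_KEY environment variable (the exact variable name is an
# assumption, not confirmed by this code):
import os

env_swarm = AutoSwarmBuilder(
    name="ChipDesign-Swarm",
    description="A swarm of specialized AI agents for chip design",
    api_key=os.getenv("OPENAI_API_KEY"),  # assumed variable name
)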
@ -1,165 +0,0 @@
from swarms import Agent, SwarmRouter

# Portfolio Analysis Specialist
portfolio_analyzer = Agent(
    agent_name="Portfolio-Analysis-Specialist",
    system_prompt="""You are an expert portfolio analyst specializing in fund analysis and selection. Your core competencies include:
- Comprehensive analysis of mutual funds, ETFs, and index funds
- Evaluation of fund performance metrics (expense ratios, tracking error, Sharpe ratio)
- Assessment of fund composition and strategy alignment
- Risk-adjusted return analysis
- Tax efficiency considerations

For each portfolio analysis:
1. Evaluate fund characteristics and performance metrics
2. Analyze expense ratios and fee structures
3. Assess historical performance and volatility
4. Compare funds within the same category
5. Consider tax implications
6. Review fund manager track record and strategy consistency

Maintain focus on cost-efficiency and alignment with investment objectives.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="portfolio_analyzer.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)

# Asset Allocation Strategist
allocation_strategist = Agent(
    agent_name="Asset-Allocation-Strategist",
    system_prompt="""You are a specialized asset allocation strategist focused on portfolio construction and optimization. Your expertise includes:
- Strategic and tactical asset allocation
- Risk tolerance assessment and portfolio matching
- Geographic and sector diversification
- Rebalancing strategy development
- Portfolio optimization using modern portfolio theory

For each allocation:
1. Analyze investor risk tolerance and objectives
2. Develop appropriate asset class weights
3. Select optimal fund combinations
4. Design rebalancing triggers and schedules
5. Consider tax-efficient fund placement
6. Account for correlation between assets

Focus on creating well-diversified portfolios aligned with client goals and risk tolerance.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="allocation_strategist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)

# Risk Management Specialist
risk_manager = Agent(
    agent_name="Risk-Management-Specialist",
    system_prompt="""You are a risk management specialist focused on portfolio risk assessment and mitigation. Your expertise covers:
- Portfolio risk metrics analysis
- Downside protection strategies
- Correlation analysis between funds
- Stress testing and scenario analysis
- Market condition impact assessment

For each portfolio:
1. Calculate key risk metrics (Beta, Standard Deviation, etc.)
2. Analyze correlation matrices
3. Perform stress tests under various scenarios
4. Evaluate liquidity risks
5. Assess concentration risks
6. Monitor factor exposures

Focus on maintaining appropriate risk levels while maximizing risk-adjusted returns.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="risk_manager.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)

# Portfolio Implementation Specialist
implementation_specialist = Agent(
    agent_name="Portfolio-Implementation-Specialist",
    system_prompt="""You are a portfolio implementation specialist focused on efficient execution and maintenance. Your responsibilities include:
- Fund selection for specific asset class exposure
- Tax-efficient implementation strategies
- Portfolio rebalancing execution
- Trading cost analysis
- Cash flow management

For each implementation:
1. Select most efficient funds for desired exposure
2. Plan tax-efficient transitions
3. Design rebalancing schedule
4. Optimize trade execution
5. Manage cash positions
6. Monitor tracking error

Maintain focus on minimizing costs and maximizing tax efficiency during implementation.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="implementation_specialist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)

# Portfolio Monitoring Specialist
monitoring_specialist = Agent(
    agent_name="Portfolio-Monitoring-Specialist",
    system_prompt="""You are a portfolio monitoring specialist focused on ongoing portfolio oversight and optimization. Your expertise includes:
- Regular portfolio performance review
- Drift monitoring and rebalancing triggers
- Fund changes and replacements
- Tax loss harvesting opportunities
- Performance attribution analysis

For each review:
1. Track portfolio drift from targets
2. Monitor fund performance and changes
3. Identify tax loss harvesting opportunities
4. Analyze tracking error and expenses
5. Review risk metrics evolution
6. Generate performance attribution reports

Ensure continuous alignment with investment objectives while maintaining optimal portfolio efficiency.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="monitoring_specialist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)

# List of all agents for portfolio management
portfolio_agents = [
    portfolio_analyzer,
    allocation_strategist,
    risk_manager,
    implementation_specialist,
    monitoring_specialist,
]


# Router
router = SwarmRouter(
    name="etf-portfolio-management-swarm",
    description="Creates and suggests an optimal portfolio",
    agents=portfolio_agents,
    swarm_type="SequentialWorkflow",  # Alternatively: "ConcurrentWorkflow"
    max_loops=1,
)

router.run(
    task="I have $10,000 and I want to create a high-growth portfolio based on energy, AI, and data center companies."
)
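# The swarm_type comment above notes that "ConcurrentWorkflow" is also accepted.
# A minimal sketch of that variant, assuming the same agents and task can simply
# be fanned out in parallel rather than run sequentially:
concurrent_router = SwarmRouter(
    name="etf-portfolio-management-swarm-concurrent",
    description="Runs every portfolio specialist on the task in parallel",
    agents=portfolio_agents,
    swarm_type="ConcurrentWorkflow",
    max_loops=1,
)

concurrent_router.run(
    task="I have $10,000 and I want to create a high-growth portfolio based on energy, AI, and data center companies."
)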