parent a673ba2d71
commit bb7e18f654
@@ -1,387 +0,0 @@
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer
import faiss
import pickle
from loguru import logger
from concurrent.futures import ThreadPoolExecutor
import threading
import uuid


@dataclass
class Document:
    """Represents a document in the HQD-RAG system.

    Attributes:
        id (str): Unique identifier for the document
        content (str): Raw text content of the document
        embedding (Optional[np.ndarray]): Quantum-inspired embedding vector
        cluster_id (Optional[int]): ID of the cluster this document belongs to
    """

    id: str
    content: str
    embedding: Optional[np.ndarray] = None
    cluster_id: Optional[int] = None


class HQDRAG:
    """
    Hierarchical Quantum-Inspired Distributed RAG (HQD-RAG) system.

    A production-grade implementation of the HQD-RAG algorithm for ultra-fast
    and reliable document retrieval. Uses quantum-inspired embeddings and
    hierarchical clustering for efficient search.

    Attributes:
        embedding_dim (int): Dimension of the quantum-inspired embeddings
        num_clusters (int): Maximum number of hierarchical clusters
        similarity_threshold (float): Threshold for quantum similarity matching
        reliability_threshold (float): Threshold for reliability verification
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        num_clusters: int = 128,
        similarity_threshold: float = 0.75,
        reliability_threshold: float = 0.85,
        model_name: str = "all-MiniLM-L6-v2",
    ):
        """Initialize the HQD-RAG system.

        Args:
            embedding_dim: Fallback embedding dimension, used only if the encoder
                does not report its own output dimension
            num_clusters: Maximum number of clusters for hierarchical organization
            similarity_threshold: Minimum similarity score for retrieval
            reliability_threshold: Minimum reliability score for verification
            model_name: Name of the sentence transformer model to use
        """
        logger.info(f"Initializing HQD-RAG with encoder {model_name}")

        self.num_clusters = num_clusters
        self.similarity_threshold = similarity_threshold
        self.reliability_threshold = reliability_threshold

        # Initialize components. The FAISS index dimension must match the
        # encoder's output, so prefer the dimension reported by the model.
        self.documents: Dict[str, Document] = {}
        self.encoder = SentenceTransformer(model_name)
        self.embedding_dim = (
            self.encoder.get_sentence_embedding_dimension() or embedding_dim
        )
        self.index = faiss.IndexFlatIP(self.embedding_dim)  # Inner-product index
        self.clustering = AgglomerativeClustering(
            n_clusters=num_clusters, metric="euclidean", linkage="ward"
        )

        # Thread safety
        self._lock = threading.Lock()
        self._executor = ThreadPoolExecutor(max_workers=4)

        logger.info("HQD-RAG system initialized successfully")

    def _compute_quantum_embedding(self, text: str) -> np.ndarray:
        """Compute quantum-inspired embedding for text.

        Args:
            text: Input text to embed

        Returns:
            Quantum-inspired embedding vector (real-valued, unit length)
        """
        # Get base embedding
        base_embedding = self.encoder.encode([text])[0]

        # Apply quantum-inspired transformation: simulate superposition by
        # modulating the vector with random phase components, then keep the
        # real part so the result stays compatible with the FAISS index.
        phase = np.exp(2j * np.pi * np.random.random(base_embedding.shape[0]))
        quantum_embedding = np.real(base_embedding * phase)

        # Normalize to unit length and cast to float32 for FAISS
        quantum_embedding = quantum_embedding / np.linalg.norm(quantum_embedding)
        return quantum_embedding.astype(np.float32)

    def _verify_reliability(self, doc: Document, query_embedding: np.ndarray) -> float:
        """Verify the reliability of a document match.

        Args:
            doc: Document to verify
            query_embedding: Query embedding vector

        Returns:
            Reliability score between 0 and 1
        """
        if doc.embedding is None:
            return 0.0

        # Compute consistency score
        consistency = np.abs(np.dot(doc.embedding, query_embedding))

        # Quantum noise-resistance check: the match should survive a small
        # perturbation of the query vector.
        noise = np.random.normal(0, 0.1, query_embedding.shape[0])
        noisy_query = query_embedding + noise
        noisy_query = noisy_query / np.linalg.norm(noisy_query)
        noise_resistance = np.abs(np.dot(doc.embedding, noisy_query))

        return float((consistency + noise_resistance) / 2)

    def add(self, content: str, doc_id: Optional[str] = None) -> str:
        """Add a document to the system.

        Args:
            content: Document text content
            doc_id: Optional custom document ID

        Returns:
            Document ID
        """
        doc_id = doc_id or str(uuid.uuid4())

        with self._lock:
            try:
                # Compute embedding
                embedding = self._compute_quantum_embedding(content)

                # Create document
                doc = Document(id=doc_id, content=content, embedding=embedding)

                # Add to storage. The FAISS index and self.documents share the
                # same insertion order, which is how results are mapped back
                # to document IDs during queries.
                self.documents[doc_id] = doc
                self.index.add(embedding.reshape(1, -1))

                # Update clustering
                self._update_clusters()

                logger.info(f"Successfully added document {doc_id}")
                return doc_id

            except Exception as e:
                logger.error(f"Error adding document: {str(e)}")
                raise

    def query(
        self,
        query: str,
        k: int = 5,
        return_scores: bool = False,
    ) -> Union[List[str], List[tuple[str, float]]]:
        """Query the system for relevant documents.

        Args:
            query: Query text
            k: Number of results to return
            return_scores: Whether to return similarity scores

        Returns:
            List of document IDs or (document ID, score) tuples
        """
        try:
            # Compute query embedding
            query_embedding = self._compute_quantum_embedding(query)

            # Search index; fetch extra results for reliability filtering
            scores, indices = self.index.search(query_embedding.reshape(1, -1), k * 2)

            doc_ids = list(self.documents.keys())
            results = []
            for score, idx in zip(scores[0], indices[0]):
                if idx < 0:  # FAISS pads with -1 when fewer results exist
                    continue

                # Get document
                doc_id = doc_ids[idx]
                doc = self.documents[doc_id]

                # Verify reliability
                reliability = self._verify_reliability(doc, query_embedding)
                if reliability >= self.reliability_threshold:
                    results.append((doc_id, float(score)))

                if len(results) >= k:
                    break

            logger.info(f"Query returned {len(results)} results")

            if return_scores:
                return results
            return [doc_id for doc_id, _ in results]

        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            raise

    def update(self, doc_id: str, new_content: str) -> None:
        """Update an existing document.

        Args:
            doc_id: ID of document to update
            new_content: New document content
        """
        with self._lock:
            try:
                if doc_id not in self.documents:
                    raise KeyError(f"Document {doc_id} not found")

                # Remove the old embedding from the index (positional ids assume
                # the documents dict insertion order mirrors the FAISS index)
                old_doc = self.documents[doc_id]
                if old_doc.embedding is not None:
                    idx = list(self.documents.keys()).index(doc_id)
                    self.index.remove_ids(np.array([idx], dtype=np.int64))

                # Compute new embedding
                new_embedding = self._compute_quantum_embedding(new_content)

                # Update document
                self.documents[doc_id] = Document(
                    id=doc_id, content=new_content, embedding=new_embedding
                )

                # Add new embedding
                self.index.add(new_embedding.reshape(1, -1))

                # Update clustering
                self._update_clusters()

                logger.info(f"Successfully updated document {doc_id}")

            except Exception as e:
                logger.error(f"Error updating document: {str(e)}")
                raise

    def delete(self, doc_id: str) -> None:
        """Delete a document from the system.

        Args:
            doc_id: ID of document to delete
        """
        with self._lock:
            try:
                if doc_id not in self.documents:
                    raise KeyError(f"Document {doc_id} not found")

                # Remove from index
                idx = list(self.documents.keys()).index(doc_id)
                self.index.remove_ids(np.array([idx], dtype=np.int64))

                # Remove from storage
                del self.documents[doc_id]

                # Update clustering
                self._update_clusters()

                logger.info(f"Successfully deleted document {doc_id}")

            except Exception as e:
                logger.error(f"Error deleting document: {str(e)}")
                raise

    def _update_clusters(self) -> None:
        """Update hierarchical document clusters."""
        if len(self.documents) < 2:
            return

        # Get all embeddings
        embeddings = np.vstack(
            [doc.embedding for doc in self.documents.values() if doc.embedding is not None]
        )

        # AgglomerativeClustering cannot produce more clusters than samples,
        # so cap the cluster count while the corpus is still small.
        self.clustering.n_clusters = min(self.num_clusters, embeddings.shape[0])
        clusters = self.clustering.fit_predict(embeddings)

        # Assign cluster IDs
        for doc, cluster_id in zip(self.documents.values(), clusters):
            doc.cluster_id = int(cluster_id)

    def save(self, path: Union[str, Path]) -> None:
        """Save the system state to disk.

        Args:
            path: Path to the save directory
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        try:
            # Save documents
            with open(path / "documents.pkl", "wb") as f:
                pickle.dump(self.documents, f)

            # Save index
            faiss.write_index(self.index, str(path / "index.faiss"))

            logger.info(f"Successfully saved system state to {path}")

        except Exception as e:
            logger.error(f"Error saving system state: {str(e)}")
            raise

    def load(self, path: Union[str, Path]) -> None:
        """Load the system state from disk.

        Args:
            path: Path to the save directory
        """
        path = Path(path)

        try:
            # Load documents
            with open(path / "documents.pkl", "rb") as f:
                self.documents = pickle.load(f)

            # Load index
            self.index = faiss.read_index(str(path / "index.faiss"))

            logger.info(f"Successfully loaded system state from {path}")

        except Exception as e:
            logger.error(f"Error loading system state: {str(e)}")
            raise


# Example usage:
if __name__ == "__main__":
    # Configure logging
    logger.add("hqd_rag.log", rotation="1 day", retention="1 week", level="INFO")

    # Initialize system
    rag = HQDRAG()

    # Add some documents
    doc_ids = []
    docs = [
        "The quick brown fox jumps over the lazy dog",
        "Machine learning is a subset of artificial intelligence",
        "Python is a popular programming language",
    ]
    for doc in docs:
        doc_ids.append(rag.add(doc))

    # Query
    results = rag.query("What is machine learning?", return_scores=True)
    print("Query results:", results)

    # # Update a document
    # rag.update(doc_ids[0], "The fast brown fox jumps over the sleepy dog")

    # # Delete a document
    # rag.delete(doc_ids[-1])

    # # Save state
    # rag.save("hqd_rag_state")
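
# Note on scoring: every embedding is L2-normalized before being added to the
# IndexFlatIP index, so the inner-product scores returned by `query` are cosine
# similarities in [-1, 1]. A tiny illustration with toy vectors (not part of the
# system itself):
#
#     import numpy as np
#     a = np.array([0.6, 0.8], dtype=np.float32)  # unit length
#     b = np.array([1.0, 0.0], dtype=np.float32)  # unit length
#     assert np.isclose(np.dot(a, b), 0.6)        # inner product == cosine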
@@ -0,0 +1,554 @@
from typing import Dict, List
from datetime import datetime
from loguru import logger
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
import asyncio
import json
import aiohttp
from bs4 import BeautifulSoup
import xml.etree.ElementTree as ET

# Configure logging
logger.add("forex_forest.log", rotation="500 MB", level="INFO")


class ForexDataFeed:
    """Real-time forex data collector using free open sources"""

    def __init__(self):
        self.pairs = ["EUR/USD", "GBP/USD", "USD/JPY", "AUD/USD", "USD/CAD"]

    async def fetch_ecb_rates(self) -> Dict:
        """Fetch exchange rates from the European Central Bank (no key required)"""
        try:
            url = "https://www.ecb.europa.eu/stats/eurofxref/eurofxref-daily.xml"
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    xml_data = await response.text()

            root = ET.fromstring(xml_data)
            rates = {}
            for cube in root.findall(".//*[@currency]"):
                currency = cube.get("currency")
                rate = float(cube.get("rate"))
                rates[currency] = rate

            # Calculate cross rates. ECB reference rates are quoted as units of
            # currency per 1 EUR, so BASE/QUOTE = rate[QUOTE] / rate[BASE].
            rates["EUR"] = 1.0  # Base currency
            cross_rates = {}
            for pair in self.pairs:
                base, quote = pair.split("/")
                if base in rates and quote in rates:
                    cross_rates[pair] = rates[quote] / rates[base]

            return cross_rates
        except Exception as e:
            logger.error(f"Error fetching ECB rates: {e}")
            return {}

    async def fetch_forex_factory_data(self) -> Dict:
        """Scrape trading data from Forex Factory"""
        try:
            url = "https://www.forexfactory.com"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
            }

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    text = await response.text()

            soup = BeautifulSoup(text, "html.parser")

            # Get calendar events
            calendar = []
            calendar_table = soup.find("table", class_="calendar__table")
            if calendar_table:
                for row in calendar_table.find_all("tr", class_="calendar__row"):
                    try:
                        event = {
                            "currency": row.find("td", class_="calendar__currency").text.strip(),
                            "event": row.find("td", class_="calendar__event").text.strip(),
                            "impact": row.find("td", class_="calendar__impact").text.strip(),
                            "time": row.find("td", class_="calendar__time").text.strip(),
                        }
                        calendar.append(event)
                    except Exception:
                        continue

            return {"calendar": calendar}
        except Exception as e:
            logger.error(f"Error fetching Forex Factory data: {e}")
            return {}

    async def fetch_tradingeconomics_data(self) -> Dict:
        """Scrape economic data from Trading Economics"""
        try:
            url = "https://tradingeconomics.com/calendar"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
            }

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    text = await response.text()

            soup = BeautifulSoup(text, "html.parser")

            # Get economic indicators
            indicators = []
            calendar_table = soup.find("table", class_="table")
            if calendar_table:
                for row in calendar_table.find_all("tr")[1:]:  # Skip header
                    try:
                        cols = row.find_all("td")
                        indicator = {
                            "country": cols[0].text.strip(),
                            "indicator": cols[1].text.strip(),
                            "actual": cols[2].text.strip(),
                            "previous": cols[3].text.strip(),
                            "consensus": cols[4].text.strip(),
                        }
                        indicators.append(indicator)
                    except Exception:
                        continue

            return {"indicators": indicators}
        except Exception as e:
            logger.error(f"Error fetching Trading Economics data: {e}")
            return {}

    async def fetch_dailyfx_data(self) -> Dict:
        """Scrape market analysis from DailyFX"""
        try:
            url = "https://www.dailyfx.com/market-news"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
            }

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    text = await response.text()

            soup = BeautifulSoup(text, "html.parser")

            # Get market news and analysis
            news = []
            articles = soup.find_all("article", class_="dfx-article")
            for article in articles[:10]:  # Get latest 10 articles
                try:
                    news_item = {
                        "title": article.find("h3").text.strip(),
                        "summary": article.find("p").text.strip(),
                        "currency": article.get("data-currency", "General"),
                        "timestamp": article.find("time").get("datetime"),
                    }
                    news.append(news_item)
                except Exception:
                    continue

            return {"news": news}
        except Exception as e:
            logger.error(f"Error fetching DailyFX data: {e}")
            return {}

    async def fetch_all_data(self) -> Dict:
        """Fetch and combine all forex data sources"""
        try:
            # Fetch data from all sources concurrently
            rates, ff_data, te_data, dx_data = await asyncio.gather(
                self.fetch_ecb_rates(),
                self.fetch_forex_factory_data(),
                self.fetch_tradingeconomics_data(),
                self.fetch_dailyfx_data(),
            )

            # Combine all data
            market_data = {
                "exchange_rates": rates,
                "calendar": ff_data.get("calendar", []),
                "economic_indicators": te_data.get("indicators", []),
                "market_news": dx_data.get("news", []),
                "timestamp": datetime.now().isoformat(),
            }

            return market_data

        except Exception as e:
            logger.error(f"Error fetching all data: {e}")
            return {}


# Rest of the ForexForestSystem class remains the same...

# (Previous ForexDataFeed class code remains the same...)
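
# Minimal sanity-check sketch (hypothetical rates, not fetched data) for the ECB
# quoting convention used in `fetch_ecb_rates`: reference rates are units of
# currency per 1 EUR, so BASE/QUOTE = rate[QUOTE] / rate[BASE].
def _example_cross_rate_check() -> None:
    rates = {"EUR": 1.0, "USD": 1.08, "GBP": 0.86}
    eur_usd = rates["USD"] / rates["EUR"]  # ~1.08 USD per EUR
    gbp_usd = rates["USD"] / rates["GBP"]  # ~1.26 USD per GBP
    assert abs(eur_usd - 1.08) < 1e-9
    assert 1.2 < gbp_usd < 1.3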

# Specialized Agent Prompts
TECHNICAL_ANALYST_PROMPT = """You are an expert forex technical analyst agent.
Your responsibilities:
1. Analyze real-time exchange rate data for patterns and trends
2. Calculate cross-rates and currency correlations
3. Generate trading signals based on price action
4. Monitor market volatility and momentum
5. Identify key support and resistance levels

Data Format:
- You will receive exchange rates from ECB and calculated cross-rates
- Focus on major currency pairs and their relationships
- Consider market volatility and trading volumes

Output Format:
{
    "analysis_type": "technical",
    "timestamp": "ISO timestamp",
    "signals": [
        {
            "pair": "Currency pair",
            "trend": "bullish/bearish/neutral",
            "strength": 1-10,
            "key_levels": {"support": [], "resistance": []},
            "recommendation": "buy/sell/hold"
        }
    ]
}"""

FUNDAMENTAL_ANALYST_PROMPT = """You are an expert forex fundamental analyst agent.
Your responsibilities:
1. Analyze economic calendar events and their impact
2. Evaluate economic indicators from Trading Economics
3. Assess market news and sentiment from DailyFX
4. Monitor central bank actions and policies
5. Track geopolitical events affecting currencies

Data Format:
- Economic calendar events with impact levels
- Latest economic indicators and previous values
- Market news and analysis from reliable sources
- Central bank statements and policy changes

Output Format:
{
    "analysis_type": "fundamental",
    "timestamp": "ISO timestamp",
    "assessments": [
        {
            "currency": "Currency code",
            "economic_outlook": "positive/negative/neutral",
            "key_events": [],
            "impact_score": 1-10,
            "bias": "bullish/bearish/neutral"
        }
    ]
}"""

MARKET_SENTIMENT_PROMPT = """You are an expert market sentiment analysis agent.
Your responsibilities:
1. Analyze news sentiment from DailyFX articles
2. Track market positioning and bias
3. Monitor risk sentiment and market fear/greed
4. Identify potential market drivers
5. Detect sentiment shifts and extremes

Data Format:
- Market news and analysis articles
- Trading sentiment indicators
- Risk event calendar
- Market commentary and analysis

Output Format:
{
    "analysis_type": "sentiment",
    "timestamp": "ISO timestamp",
    "sentiment_data": [
        {
            "pair": "Currency pair",
            "sentiment": "risk-on/risk-off",
            "strength": 1-10,
            "key_drivers": [],
            "outlook": "positive/negative/neutral"
        }
    ]
}"""

STRATEGY_COORDINATOR_PROMPT = """You are the lead forex strategy coordination agent.
Your responsibilities:
1. Synthesize technical, fundamental, and sentiment analysis
2. Generate final trading recommendations
3. Manage risk exposure and position sizing
4. Coordinate entry and exit points
5. Monitor open positions and adjust strategies

Data Format:
- Analysis from technical, fundamental, and sentiment agents
- Current market rates and conditions
- Economic calendar and news events
- Risk parameters and exposure limits

Output Format:
{
    "analysis_type": "strategy",
    "timestamp": "ISO timestamp",
    "recommendations": [
        {
            "pair": "Currency pair",
            "action": "buy/sell/hold",
            "confidence": 1-10,
            "entry_points": [],
            "stop_loss": float,
            "take_profit": float,
            "rationale": "string"
        }
    ]
}"""
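
# The prompts above ask each agent to reply with a JSON object. A small,
# defensive parsing sketch (what the model actually returns may deviate from
# the documented schema, so missing keys are treated as optional):
def _parse_agent_response(raw: str) -> Dict:
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("Agent response was not valid JSON")
        return {}
    return payload if isinstance(payload, dict) else {}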

class ForexForestSystem:
    """Main system coordinating the forest swarm and data feeds"""

    def __init__(self):
        """Initialize the forex forest system"""
        self.data_feed = ForexDataFeed()

        # Create Technical Analysis agents
        technical_agents = [
            TreeAgent(
                system_prompt=TECHNICAL_ANALYST_PROMPT,
                agent_name=name,
                model_name="gpt-4o",
            )
            for name in (
                "Price Action Analyst",
                "Cross Rate Analyst",
                "Volatility Analyst",
            )
        ]

        # Create Fundamental Analysis agents
        fundamental_agents = [
            TreeAgent(
                system_prompt=FUNDAMENTAL_ANALYST_PROMPT,
                agent_name=name,
                model_name="gpt-4o",
            )
            for name in (
                "Economic Data Analyst",
                "News Impact Analyst",
                "Central Bank Analyst",
            )
        ]

        # Create Sentiment Analysis agents
        sentiment_agents = [
            TreeAgent(
                system_prompt=MARKET_SENTIMENT_PROMPT,
                agent_name=name,
                model_name="gpt-4o",
            )
            for name in (
                "News Sentiment Analyst",
                "Risk Sentiment Analyst",
                "Market Positioning Analyst",
            )
        ]

        # Create Strategy Coordination agents
        strategy_agents = [
            TreeAgent(
                system_prompt=STRATEGY_COORDINATOR_PROMPT,
                agent_name=name,
                model_name="gpt-4",
                temperature=0.5,
            )
            for name in (
                "Lead Strategy Coordinator",
                "Risk Manager",
                "Position Manager",
            )
        ]

        # Create trees
        self.technical_tree = Tree(
            tree_name="Technical Analysis", agents=technical_agents
        )
        self.fundamental_tree = Tree(
            tree_name="Fundamental Analysis", agents=fundamental_agents
        )
        self.sentiment_tree = Tree(
            tree_name="Sentiment Analysis", agents=sentiment_agents
        )
        self.strategy_tree = Tree(
            tree_name="Strategy Coordination", agents=strategy_agents
        )

        # Create forest swarm
        self.forest = ForestSwarm(
            trees=[
                self.technical_tree,
                self.fundamental_tree,
                self.sentiment_tree,
                self.strategy_tree,
            ]
        )

        logger.info("Forex Forest System initialized successfully")

    async def prepare_analysis_task(self) -> str:
        """Prepare the analysis task with real-time data"""
        try:
            market_data = await self.data_feed.fetch_all_data()

            task = {
                "action": "analyze_forex_markets",
                "market_data": market_data,
                "timestamp": datetime.now().isoformat(),
                "analysis_required": [
                    "technical",
                    "fundamental",
                    "sentiment",
                    "strategy",
                ],
            }

            return json.dumps(task, indent=2)

        except Exception as e:
            logger.error(f"Error preparing analysis task: {e}")
            raise

    async def run_analysis_cycle(self) -> Dict:
        """Run a complete analysis cycle with the forest swarm"""
        try:
            # Prepare task with real-time data
            task = await self.prepare_analysis_task()

            # Run forest swarm analysis
            result = self.forest.run(task)

            # Parse and validate results
            analysis = json.loads(result) if isinstance(result, str) else result

            logger.info("Analysis cycle completed successfully")
            return analysis

        except Exception as e:
            logger.error(f"Error in analysis cycle: {e}")
            raise

    async def monitor_markets(self, interval_seconds: int = 300):
        """Continuously monitor markets and run analysis"""
        while True:
            try:
                # Run analysis cycle
                analysis = await self.run_analysis_cycle()

                # Log results
                logger.info("Market analysis completed")
                logger.debug(f"Analysis results: {json.dumps(analysis, indent=2)}")

                # Process any trading signals
                if "recommendations" in analysis:
                    await self.process_trading_signals(analysis["recommendations"])

                # Wait for the next interval
                await asyncio.sleep(interval_seconds)

            except Exception as e:
                logger.error(f"Error in market monitoring: {e}")
                await asyncio.sleep(60)

    async def process_trading_signals(self, recommendations: List[Dict]) -> None:
        """Process and log trading signals from analysis"""
        try:
            for rec in recommendations:
                logger.info(f"Trading Signal: {rec['pair']} - {rec['action']}")
                logger.info(f"Confidence: {rec['confidence']}/10")
                logger.info(f"Entry Points: {rec['entry_points']}")
                logger.info(f"Stop Loss: {rec['stop_loss']}")
                logger.info(f"Take Profit: {rec['take_profit']}")
                logger.info(f"Rationale: {rec['rationale']}")
                logger.info("-" * 50)

        except Exception as e:
            logger.error(f"Error processing trading signals: {e}")


# Example usage
async def main():
    """Main function to run the Forex Forest System"""
    try:
        system = ForexForestSystem()
        await system.monitor_markets()
    except Exception as e:
        logger.error(f"Error in main: {e}")


if __name__ == "__main__":
    # Set up the asyncio event loop and run the system
    asyncio.run(main())
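
# To run a single analysis pass instead of the endless monitoring loop, the same
# pieces can be composed directly (sketch; assumes valid API credentials for the
# underlying models):
#
#     system = ForexForestSystem()
#     analysis = asyncio.run(system.run_analysis_cycle())
#     print(json.dumps(analysis, indent=2))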
@@ -1,618 +0,0 @@
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from loguru import logger

from dataclasses import dataclass
from typing import Optional, Tuple, Dict
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


@dataclass
class TransformerConfig:
    """Configuration class for MoE Transformer model parameters."""

    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_expert_layers: int = 4
    num_experts: int = 8
    expert_capacity: int = 32
    max_position_embeddings: int = 1024
    dropout_prob: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    num_query_groups: int = 4  # For multi-query attention


class ExpertLayer(nn.Module):
    """Individual expert neural network."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
        self.fc2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MixtureOfExperts(nn.Module):
    """Mixture of Experts layer with dynamic routing."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity

        # Create expert networks
        self.experts = nn.ModuleList(
            [ExpertLayer(config) for _ in range(config.num_experts)]
        )

        # Router network
        self.router = nn.Linear(config.hidden_size, config.num_experts)

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Route inputs to experts and combine outputs."""
        batch_size, seq_len, hidden_size = x.shape

        # Calculate routing probabilities
        router_logits = self.router(x)
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k = 2
        gates, indices = torch.topk(routing_weights, top_k, dim=-1)
        gates = F.softmax(gates, dim=-1)

        # Process inputs through selected experts
        final_output = torch.zeros_like(x)
        router_load = torch.zeros(self.num_experts, device=x.device)

        for i in range(top_k):
            expert_index = indices[..., i]
            gate = gates[..., i : i + 1]

            # Count expert assignments
            for j in range(self.num_experts):
                router_load[j] += (expert_index == j).float().sum()

            # Process through selected experts
            for j in range(self.num_experts):
                mask = expert_index == j
                if not mask.any():
                    continue

                expert_input = x[mask]
                expert_output = self.experts[j](expert_input)
                final_output[mask] += gate[mask] * expert_output

        aux_loss = router_load.float().var() / (router_load.float().mean() ** 2)

        return final_output, {"load_balancing_loss": aux_loss}
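
# Illustration of the auxiliary load-balancing term returned above: it is the
# squared coefficient of variation of the per-expert assignment counts, so a
# perfectly even load gives 0 and a skewed load gives a positive penalty
# (toy counts, not real router output):
def _example_load_balancing_loss() -> None:
    even = torch.tensor([4.0, 4.0, 4.0, 4.0])
    skewed = torch.tensor([13.0, 1.0, 1.0, 1.0])
    assert (even.var() / even.mean() ** 2).item() == 0.0
    assert (skewed.var() / skewed.mean() ** 2).item() > 0.0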

class MultiQueryAttention(nn.Module):
    """Multi-Query Attention mechanism with proper multi-query group handling."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.num_query_groups = config.num_query_groups
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Query projection maintains the full head dimension
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Key and value projections use a reduced number of heads (query groups)
        self.k_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )

        self.dropout = nn.Dropout(config.dropout_prob)

        # Heads per group, used to broadcast keys/values up to the full head count
        self.heads_per_group = self.num_attention_heads // self.num_query_groups

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        # `cache` is accepted for interface compatibility with incremental
        # decoding but is not populated by this implementation.
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Reshape queries to the full number of heads
        queries = queries.view(
            batch_size, seq_length, self.num_attention_heads, self.head_dim
        )

        # Reshape keys and values to the number of query groups
        keys = keys.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )
        values = values.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )

        # Transpose for batch matrix multiplication
        queries = queries.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)
        values = values.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)

        # Repeat keys and values for each head in the group
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # Expand the padding mask to match the score dimensions
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(
                batch_size, self.num_attention_heads, seq_length, seq_length
            )
            mask_value = torch.finfo(scores.dtype).min
            additive_mask = expanded_mask.eq(0).float() * mask_value
            scores = scores + additive_mask

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute attention output
        attention_output = torch.matmul(attention_weights, values)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_length, -1)

        return attention_output, None
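
# Quick shape sketch for the grouped key/value projection above: queries keep
# all attention heads while keys/values carry only `num_query_groups` heads and
# are repeat-interleaved up to the full head count before the matmul. Toy sizes:
def _example_mqa_shapes() -> None:
    cfg = TransformerConfig(hidden_size=64, num_attention_heads=8, num_query_groups=4)
    attention = MultiQueryAttention(cfg)
    hidden = torch.randn(2, 10, cfg.hidden_size)
    output, _ = attention(hidden)
    assert output.shape == (2, 10, cfg.hidden_size)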

class MoETransformer(nn.Module):
    """
    Production-grade Transformer model with Mixture of Experts and Multi-Query Attention.

    Features:
    - Multi-Query Attention mechanism for efficient inference
    - Mixture of Experts for dynamic routing and specialization
    - Real-time weight updates based on input similarity
    - Built-in logging and monitoring
    - Type annotations for better code maintainability
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Initialize components
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Multi-Query Attention layers
        self.attention_layers = nn.ModuleList(
            [MultiQueryAttention(config) for _ in range(config.num_expert_layers)]
        )

        # Mixture of Experts layers
        self.moe_layers = nn.ModuleList(
            [MixtureOfExperts(config) for _ in range(config.num_expert_layers)]
        )

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_prob)

        # Output projection
        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights
        self.apply(self._init_weights)
        logger.info("Initialized MoETransformer model")

    def _init_weights(self, module: nn.Module):
        """Initialize model weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def get_position_embeddings(self, position_ids: Tensor) -> Tensor:
        """Generate position embeddings."""
        return self.position_embedding(position_ids)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Dict]:
        """
        Forward pass through the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask for padding
            position_ids: Position IDs for positional encoding
            cache: Cache for key/value states in generation

        Returns:
            tuple: (logits, auxiliary_outputs)
        """
        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Get embeddings
        inputs_embeds = self.embedding(input_ids)
        position_embeds = self.get_position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Initialize auxiliary outputs
        aux_outputs = {"moe_losses": []}

        # Process through transformer layers
        for attention_layer, moe_layer in zip(self.attention_layers, self.moe_layers):
            # Multi-Query Attention
            attention_output, _ = attention_layer(
                hidden_states, attention_mask, cache
            )
            hidden_states = self.layer_norm(hidden_states + attention_output)

            # Mixture of Experts
            moe_output, moe_aux = moe_layer(hidden_states)
            hidden_states = self.layer_norm(hidden_states + moe_output)
            aux_outputs["moe_losses"].append(moe_aux["load_balancing_loss"])

        # Final output projection
        logits = self.output_projection(hidden_states)

        return logits, aux_outputs

    def fetch_loss(
        self,
        logits: Tensor,
        labels: Tensor,
        aux_outputs: Dict,
        reduction: str = "mean",
    ) -> Tensor:
        """
        Calculate the total loss including MoE balancing losses.

        Args:
            logits: Model output logits
            labels: Ground truth labels
            aux_outputs: Auxiliary outputs from forward pass
            reduction: Loss reduction method

        Returns:
            Tensor: Total loss
        """
        # Calculate cross entropy loss
        ce_loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            labels.view(-1),
            reduction=reduction,
        )

        # Calculate MoE loss
        moe_loss = torch.stack(aux_outputs["moe_losses"]).mean()

        # Combine losses
        total_loss = ce_loss + 0.01 * moe_loss

        logger.debug(
            f"CE Loss: {ce_loss.item():.4f}, MoE Loss: {moe_loss.item():.4f}"
        )

        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> Tensor:
        """
        Generate text using the model.

        Args:
            input_ids: Initial input tokens
            max_length: Maximum sequence length to generate
            temperature: Sampling temperature
            top_k: Number of highest probability tokens to keep
            top_p: Cumulative probability for nucleus sampling

        Returns:
            Tensor: Generated token IDs
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize sequence with input_ids
        generated = input_ids

        # Cache for key-value pairs (not yet populated by the attention layers)
        cache = {}

        for _ in range(max_length):
            # Get position IDs for the current sequence
            position_ids = torch.arange(
                generated.shape[1], dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            # Forward pass
            logits, _ = self.forward(
                generated, position_ids=position_ids, cache=cache
            )

            # Get next token logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                indices_to_remove = (
                    next_token_logits
                    < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold,
                # always keeping the most probable token
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                # Scatter the mask back to the original (unsorted) token order
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to sequence
            generated = torch.cat((generated, next_token), dim=1)

            # Check for end-of-sequence token (last vocabulary id treated as EOS)
            if (next_token == self.config.vocab_size - 1).all():
                break

        return generated
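
# Small standalone illustration of the nucleus (top-p) filtering step used in
# `generate`: tokens are sorted by probability and everything past the smallest
# prefix whose cumulative probability exceeds `top_p` is masked out (toy logits):
def _example_top_p_filter(top_p: float = 0.9) -> Tensor:
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    to_remove = cumulative > top_p
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = 0
    mask = to_remove.scatter(1, sorted_indices, to_remove)
    return logits.masked_fill(mask, float("-inf"))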

# Initialize model configuration
config = TransformerConfig(
    vocab_size=50257,
    hidden_size=768,
    num_attention_heads=12,
    num_expert_layers=4,
    num_experts=8,
    expert_capacity=32,
    max_position_embeddings=1024,
    num_query_groups=4,
)


def prepare_sample_data(
    batch_size: int = 8,
    seq_length: int = 512,
    vocab_size: int = 50257,
) -> DataLoader:
    """Create sample data for demonstration."""
    # Create random input sequences (100 samples)
    input_ids = torch.randint(0, vocab_size, (100, seq_length))

    # Create random target labels, purely for demonstration; a real setup
    # would use the input sequence shifted by one position
    labels = torch.randint(0, vocab_size, (100, seq_length))

    # Create attention masks (1 for real tokens, 0 for padding)
    attention_mask = torch.ones_like(input_ids)

    # Create dataset and dataloader
    dataset = TensorDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader


def train_step(
    model: MoETransformer,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
    """Execute a single training step."""
    model.train()
    optimizer.zero_grad()

    # Unpack batch
    input_ids, attention_mask, labels = [b.to(device) for b in batch]

    # Forward pass
    logits, aux_outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Calculate loss
    loss = model.fetch_loss(logits, labels, aux_outputs)

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()


def main():
    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Initialize model
    model = MoETransformer(config).to(device)
    logger.info("Model initialized")

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    # Prepare data
    dataloader = prepare_sample_data()
    logger.info("Data prepared")

    # Training loop
    num_epochs = 3
    for epoch in range(num_epochs):
        epoch_losses = []

        for batch_idx, batch in enumerate(dataloader):
            loss = train_step(model, batch, optimizer, device)
            epoch_losses.append(loss)

            if batch_idx % 10 == 0:
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs} "
                    f"Batch {batch_idx}/{len(dataloader)} "
                    f"Loss: {loss:.4f}"
                )

        avg_loss = np.mean(epoch_losses)
        logger.info(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    # Generation example
    model.eval()
    with torch.no_grad():
        # Prepare input prompt
        prompt = torch.randint(0, config.vocab_size, (1, 10)).to(device)

        # Generate sequence
        generated = model.generate(
            input_ids=prompt,
            max_length=50,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

        logger.info(f"Generated sequence shape: {generated.shape}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,333 @@
import os
from typing import List, Optional
from datetime import datetime

from pydantic import BaseModel, Field, field_validator
from loguru import logger
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
)

from swarm_models import OpenAIFunctionCaller, OpenAIChat
from swarms.structs.agent import Agent
from swarms.structs.swarm_router import SwarmRouter
from swarms.structs.agents_available import showcase_available_agents


BOSS_SYSTEM_PROMPT = """
Manage a swarm of worker agents to efficiently serve the user by deciding whether to create new agents or delegate tasks. Ensure operations are efficient and effective.

### Instructions:

1. **Task Assignment**:
   - Analyze available worker agents when a task is presented.
   - Delegate tasks to existing agents with clear, direct, and actionable instructions if an appropriate agent is available.
   - If no suitable agent exists, create a new agent with a fitting system prompt to handle the task.

2. **Agent Creation**:
   - Name agents according to the task they are intended to perform (e.g., "Twitter Marketing Agent").
   - Provide each new agent with a concise and clear system prompt that includes its role, objectives, and any tools it can utilize.

3. **Efficiency**:
   - Minimize redundancy and maximize task completion speed.
   - Avoid unnecessary agent creation if an existing agent can fulfill the task.

4. **Communication**:
   - Be explicit in task delegation instructions to avoid ambiguity and ensure effective task execution.
   - Require agents to report back on task completion or encountered issues.

5. **Reasoning and Decisions**:
   - Offer brief reasoning when selecting or creating agents to maintain transparency.
   - Avoid using an agent if unnecessary, with a clear explanation if no agents are suitable for a task.

# Output Format

Present your plan in clear, bullet-point format or short concise paragraphs, outlining task assignment, agent creation, efficiency strategies, and communication protocols.

# Notes

- Preserve transparency by always providing reasoning for task-agent assignments and creation.
- Ensure instructions to agents are unambiguous to minimize error.

"""
class AgentConfig(BaseModel):
|
||||||
|
"""Configuration for an individual agent in a swarm"""
|
||||||
|
|
||||||
|
name: str = Field(
|
||||||
|
description="The name of the agent", example="Research-Agent"
|
||||||
|
)
|
||||||
|
description: str = Field(
|
||||||
|
description="A description of the agent's purpose and capabilities",
|
||||||
|
example="Agent responsible for researching and gathering information",
|
||||||
|
)
|
||||||
|
system_prompt: str = Field(
|
||||||
|
description="The system prompt that defines the agent's behavior",
|
||||||
|
example="You are a research agent. Your role is to gather and analyze information...",
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("name")
|
||||||
|
def validate_name(cls, v):
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("Agent name cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@validator("system_prompt")
|
||||||
|
def validate_system_prompt(cls, v):
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("System prompt cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmConfig(BaseModel):
|
||||||
|
"""Configuration for a swarm of cooperative agents"""
|
||||||
|
|
||||||
|
name: str = Field(
|
||||||
|
description="The name of the swarm",
|
||||||
|
example="Research-Writing-Swarm",
|
||||||
|
)
|
||||||
|
description: str = Field(
|
||||||
|
description="The description of the swarm's purpose and capabilities",
|
||||||
|
example="A swarm of agents that work together to research topics and write articles",
|
||||||
|
)
|
||||||
|
agents: List[AgentConfig] = Field(
|
||||||
|
description="The list of agents that make up the swarm",
|
||||||
|
min_items=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("agents")
|
||||||
|
def validate_agents(cls, v):
|
||||||
|
if not v:
|
||||||
|
raise ValueError("Swarm must have at least one agent")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
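
# Illustrative sketch (not part of the original module): a SwarmConfig built by
# hand, mirroring the structured output the boss model is asked to produce via
# `base_model=SwarmConfig` in AutoSwarmBuilder._create_agents below. The field
# values simply reuse the `example=` strings declared above;
# `_example_swarm_config` itself is a hypothetical name.
_example_swarm_config = SwarmConfig(
    name="Research-Writing-Swarm",
    description="A swarm of agents that work together to research topics and write articles",
    agents=[
        AgentConfig(
            name="Research-Agent",
            description="Agent responsible for researching and gathering information",
            system_prompt="You are a research agent. Your role is to gather and analyze information...",
        )
    ],
)
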
class AutoSwarmBuilder:
    """A class that automatically builds and manages swarms of AI agents with enhanced error handling."""

    def __init__(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        verbose: bool = True,
        api_key: Optional[str] = None,
        model_name: str = "gpt-4",
    ):
        self.name = name or "DefaultSwarm"
        self.description = description or "Generic AI Agent Swarm"
        self.verbose = verbose
        self.agents_pool = []
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.model_name = model_name

        if not self.api_key:
            raise ValueError(
                "OpenAI API key must be provided either through initialization or environment variable"
            )

        logger.info(
            "Initialized AutoSwarmBuilder",
            extra={
                "swarm_name": self.name,
                "description": self.description,
                "model": self.model_name,
            },
        )

        # Initialize OpenAI chat model
        try:
            self.chat_model = OpenAIChat(
                openai_api_key=self.api_key,
                model_name=self.model_name,
                temperature=0.1,
            )
        except Exception as e:
            logger.error(
                f"Failed to initialize OpenAI chat model: {str(e)}"
            )
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def run(self, task: str, image_url: Optional[str] = None) -> str:
        """Run the swarm on a given task with error handling and retries."""
        if not task or not task.strip():
            raise ValueError("Task cannot be empty")

        logger.info("Starting swarm execution", extra={"task": task})

        try:
            # Create agents for the task
            agents = self._create_agents(task, image_url)
            if not agents:
                raise ValueError(
                    "No agents were created for the task"
                )

            # Execute the task through the swarm router
            logger.info(
                "Routing task through swarm",
                extra={"num_agents": len(agents)},
            )
            output = self.swarm_router(agents, task, image_url)

            logger.info("Swarm execution completed successfully")
            return output

        except Exception as e:
            logger.error(
                f"Error during swarm execution: {str(e)}",
                exc_info=True,
            )
            raise

    def _create_agents(
        self, task: str, image_url: Optional[str] = None
    ) -> List[Agent]:
        """Create the necessary agents for a task with enhanced error handling."""
        logger.info("Creating agents for task", extra={"task": task})

        try:
            model = OpenAIFunctionCaller(
                system_prompt=BOSS_SYSTEM_PROMPT,
                api_key=self.api_key,
                temperature=0.1,
                base_model=SwarmConfig,
            )

            agents_config = model.run(task)
            print(f"{agents_config}")

            if isinstance(agents_config, dict):
                agents_config = SwarmConfig(**agents_config)

            # Update swarm configuration
            self.name = agents_config.name
            self.description = agents_config.description

            # Create agents from configuration
            agents = []
            for agent_config in agents_config.agents:
                if isinstance(agent_config, dict):
                    agent_config = AgentConfig(**agent_config)

                agent = self.build_agent(
                    agent_name=agent_config.name,
                    agent_description=agent_config.description,
                    agent_system_prompt=agent_config.system_prompt,
                )
                agents.append(agent)

            # Add available agents showcase to system prompts
            agents_available = showcase_available_agents(
                name=self.name,
                description=self.description,
                agents=agents,
            )

            for agent in agents:
                agent.system_prompt += "\n" + agents_available

            logger.info(
                "Successfully created agents",
                extra={"num_agents": len(agents)},
            )
            return agents

        except Exception as e:
            logger.error(
                f"Error creating agents: {str(e)}", exc_info=True
            )
            raise

    def build_agent(
        self,
        agent_name: str,
        agent_description: str,
        agent_system_prompt: str,
    ) -> Agent:
        """Build a single agent with enhanced error handling."""
        logger.info(
            "Building agent", extra={"agent_name": agent_name}
        )

        try:
            agent = Agent(
                agent_name=agent_name,
                description=agent_description,
                system_prompt=agent_system_prompt,
                llm=self.chat_model,
                autosave=True,
                dashboard=False,
                verbose=self.verbose,
                dynamic_temperature_enabled=True,
                saved_state_path=f"states/{agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                user_name="swarms_corp",
                retry_attempts=3,
                context_length=200000,
                return_step_meta=False,
                output_type="str",
                streaming_on=False,
                auto_generate_prompt=True,
            )
            return agent

        except Exception as e:
            logger.error(
                f"Error building agent: {str(e)}", exc_info=True
            )
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def swarm_router(
        self,
        agents: List[Agent],
        task: str,
        image_url: Optional[str] = None,
    ) -> str:
        """Route tasks between agents in the swarm with error handling and retries."""
        logger.info(
            "Initializing swarm router",
            extra={"num_agents": len(agents)},
        )

        try:
            swarm_router_instance = SwarmRouter(
                name=self.name,
                description=self.description,
                agents=agents,
                swarm_type="auto",
            )

            formatted_task = f"{self.name} {self.description} {task}"
            result = swarm_router_instance.run(formatted_task)

            logger.info("Successfully completed swarm routing")
            return result

        except Exception as e:
            logger.error(
                f"Error in swarm router: {str(e)}", exc_info=True
            )
            raise


swarm = AutoSwarmBuilder(
    name="ChipDesign-Swarm",
    description="A swarm of specialized AI agents for chip design",
    api_key="your-api-key",  # Optional if set in environment
    model_name="gpt-4",  # Optional, defaults to gpt-4
)

result = swarm.run(
    "Design a new AI accelerator chip optimized for transformer model inference..."
)
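
# Illustrative sketch (not part of the original module): the same builder, but
# relying on the OPENAI_API_KEY environment variable instead of a hard-coded key;
# the constructor falls back to os.getenv("OPENAI_API_KEY") when api_key is
# omitted. `env_swarm` is a hypothetical name.
env_swarm = AutoSwarmBuilder(
    name="ChipDesign-Swarm",
    description="A swarm of specialized AI agents for chip design",
)
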
@@ -1,165 +0,0 @@
from swarms import Agent, SwarmRouter


# Portfolio Analysis Specialist
portfolio_analyzer = Agent(
    agent_name="Portfolio-Analysis-Specialist",
    system_prompt="""You are an expert portfolio analyst specializing in fund analysis and selection. Your core competencies include:
- Comprehensive analysis of mutual funds, ETFs, and index funds
- Evaluation of fund performance metrics (expense ratios, tracking error, Sharpe ratio)
- Assessment of fund composition and strategy alignment
- Risk-adjusted return analysis
- Tax efficiency considerations

For each portfolio analysis:
1. Evaluate fund characteristics and performance metrics
2. Analyze expense ratios and fee structures
3. Assess historical performance and volatility
4. Compare funds within the same category
5. Consider tax implications
6. Review fund manager track record and strategy consistency

Maintain focus on cost-efficiency and alignment with investment objectives.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="portfolio_analyzer.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)


# Asset Allocation Strategist
allocation_strategist = Agent(
    agent_name="Asset-Allocation-Strategist",
    system_prompt="""You are a specialized asset allocation strategist focused on portfolio construction and optimization. Your expertise includes:
- Strategic and tactical asset allocation
- Risk tolerance assessment and portfolio matching
- Geographic and sector diversification
- Rebalancing strategy development
- Portfolio optimization using modern portfolio theory

For each allocation:
1. Analyze investor risk tolerance and objectives
2. Develop appropriate asset class weights
3. Select optimal fund combinations
4. Design rebalancing triggers and schedules
5. Consider tax-efficient fund placement
6. Account for correlation between assets

Focus on creating well-diversified portfolios aligned with client goals and risk tolerance.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="allocation_strategist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)


# Risk Management Specialist
risk_manager = Agent(
    agent_name="Risk-Management-Specialist",
    system_prompt="""You are a risk management specialist focused on portfolio risk assessment and mitigation. Your expertise covers:
- Portfolio risk metrics analysis
- Downside protection strategies
- Correlation analysis between funds
- Stress testing and scenario analysis
- Market condition impact assessment

For each portfolio:
1. Calculate key risk metrics (Beta, Standard Deviation, etc.)
2. Analyze correlation matrices
3. Perform stress tests under various scenarios
4. Evaluate liquidity risks
5. Assess concentration risks
6. Monitor factor exposures

Focus on maintaining appropriate risk levels while maximizing risk-adjusted returns.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="risk_manager.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)


# Portfolio Implementation Specialist
implementation_specialist = Agent(
    agent_name="Portfolio-Implementation-Specialist",
    system_prompt="""You are a portfolio implementation specialist focused on efficient execution and maintenance. Your responsibilities include:
- Fund selection for specific asset class exposure
- Tax-efficient implementation strategies
- Portfolio rebalancing execution
- Trading cost analysis
- Cash flow management

For each implementation:
1. Select the most efficient funds for the desired exposure
2. Plan tax-efficient transitions
3. Design rebalancing schedule
4. Optimize trade execution
5. Manage cash positions
6. Monitor tracking error

Maintain focus on minimizing costs and maximizing tax efficiency during implementation.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="implementation_specialist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)


# Portfolio Monitoring Specialist
monitoring_specialist = Agent(
    agent_name="Portfolio-Monitoring-Specialist",
    system_prompt="""You are a portfolio monitoring specialist focused on ongoing portfolio oversight and optimization. Your expertise includes:
- Regular portfolio performance review
- Drift monitoring and rebalancing triggers
- Fund changes and replacements
- Tax loss harvesting opportunities
- Performance attribution analysis

For each review:
1. Track portfolio drift from targets
2. Monitor fund performance and changes
3. Identify tax loss harvesting opportunities
4. Analyze tracking error and expenses
5. Review risk metrics evolution
6. Generate performance attribution reports

Ensure continuous alignment with investment objectives while maintaining optimal portfolio efficiency.""",
    model_name="gpt-4o",
    max_loops=1,
    saved_state_path="monitoring_specialist.json",
    user_name="investment_team",
    retry_attempts=2,
    context_length=200000,
    output_type="string",
)


# List of all agents for portfolio management
portfolio_agents = [
    portfolio_analyzer,
    allocation_strategist,
    risk_manager,
    implementation_specialist,
    monitoring_specialist,
]


# Router
router = SwarmRouter(
    name="etf-portfolio-management-swarm",
    description="Creates and suggests an optimal portfolio",
    agents=portfolio_agents,
    swarm_type="SequentialWorkflow",  # ConcurrentWorkflow
    max_loops=1,
)

router.run(
    task="I have $10,000 and I want to create a portfolio based on energy, AI, and datacenter companies. High growth."
)
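
# Illustrative sketch (not part of the original script): the inline comment above
# marks "ConcurrentWorkflow" as an alternative swarm_type; assuming SwarmRouter
# accepts it, this variant would run the five portfolio agents in parallel rather
# than sequentially. `concurrent_router` is a hypothetical name.
concurrent_router = SwarmRouter(
    name="etf-portfolio-management-swarm",
    description="Creates and suggests an optimal portfolio",
    agents=portfolio_agents,
    swarm_type="ConcurrentWorkflow",
    max_loops=1,
)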
File diff suppressed because it is too large