parent 40a78c4d06
commit 3b99f3db17
@@ -0,0 +1,112 @@
import requests
import json
from time import sleep

BASE_URL = "http://api.swarms.ai:8000"


def make_request(method, endpoint, data=None):
    """Helper function to make requests with error handling"""
    url = f"{BASE_URL}{endpoint}"
    try:
        if method == "GET":
            response = requests.get(url)
        elif method == "POST":
            response = requests.post(url, json=data)
        elif method == "DELETE":
            response = requests.delete(url)
        else:
            # Guard against typos: the original left `response` unbound here
            raise ValueError(f"Unsupported HTTP method: {method}")

        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error making {method} request to {endpoint}: {str(e)}")
        # e.response is None for connection-level failures
        if e.response is not None:
            print(f"Response text: {e.response.text}")
        return None


def create_agent():
    """Create a test agent"""
    data = {
        "agent_name": "test_agent",
        "model_name": "gpt-4",
        "system_prompt": "You are a helpful assistant",
        "description": "Test agent",
        "temperature": 0.7,
        "max_loops": 1,
        "tags": ["test"],
    }
    return make_request("POST", "/v1/agent", data)


def list_agents():
    """List all agents"""
    return make_request("GET", "/v1/agents")


def test_completion(agent_id):
    """Test a completion with the agent"""
    data = {
        "prompt": "Say hello!",
        "agent_id": agent_id,
        "max_tokens": 100,
    }
    return make_request("POST", "/v1/agent/completions", data)


def get_agent_metrics(agent_id):
    """Get metrics for an agent"""
    return make_request("GET", f"/v1/agent/{agent_id}/metrics")


def delete_agent(agent_id):
    """Delete an agent"""
    return make_request("DELETE", f"/v1/agent/{agent_id}")


def run_tests():
    print("Starting API tests...")

    # Create an agent
    print("\n1. Creating agent...")
    agent_response = create_agent()
    if not agent_response:
        print("Failed to create agent")
        return

    agent_id = agent_response.get("agent_id")
    print(f"Created agent with ID: {agent_id}")

    # Give the server a moment to process
    sleep(2)

    # List agents
    print("\n2. Listing agents...")
    agents = list_agents()
    # list_agents() returns None on a request failure, so guard before len()
    if agents is not None:
        print(f"Found {len(agents)} agents")

    # Test completion
    if agent_id:
        print("\n3. Testing completion...")
        completion = test_completion(agent_id)
        if completion:
            print(f"Completion response: {completion.get('response')}")

        print("\n4. Getting agent metrics...")
        metrics = get_agent_metrics(agent_id)
        if metrics:
            print(f"Agent metrics: {json.dumps(metrics, indent=2)}")

        # Clean up
        # print("\n5. Cleaning up - deleting agent...")
        # delete_result = delete_agent(agent_id)
        # if delete_result:
        #     print("Successfully deleted agent")


if __name__ == "__main__":
    run_tests()
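
Side note: since every endpoint wrapper above funnels through `make_request`, the helper can be smoke-tested on its own before running the full suite. A minimal sketch, reusing only the read-only `/v1/agents` endpoint already used by `list_agents`:

    # Read-only probe built on the helper above
    agents = make_request("GET", "/v1/agents")
    print("API reachable" if agents is not None else "API unreachable")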
@@ -1,898 +0,0 @@
import io
import struct
import wave
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple, Union

import magic  # python-magic; used by ModalityDetector but missing from the original imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from PIL import Image
from torch import Tensor


@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False


class MultiQueryAttention(nn.Module):
    """Fixed Multi-Query Attention implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_query_heads * self.head_dim
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim, config.hidden_size
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project and reshape
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [seq_len, batch, heads, head_dim]
        q = q.view(
            batch_size,
            seq_length,
            self.num_query_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)
        k = k.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)
        v = v.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)

        # Apply rotary embeddings
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Handle MQA head expansion
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=2,
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=2,
            )

        # Compute attention
        # Reshape for matmul: [batch, heads, seq_length, head_dim]
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)

        attn_weights = (
            torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)

        return output


class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size
        )

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(
                                config.hidden_size,
                                4 * config.hidden_size,
                            ),
                            nn.GELU(),
                            nn.Linear(
                                4 * config.hidden_size,
                                config.hidden_size,
                            ),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Create causal mask if needed
        if attention_mask is None:
            causal_mask = torch.triu(
                torch.ones(
                    (input_ids.size(1), input_ids.size(1)),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            # Convert the boolean mask to an additive float mask;
            # masked_fill on the bool tensor itself (as in the original)
            # would fail because -inf cannot be cast to bool.
            attention_mask = torch.zeros(
                causal_mask.shape,
                device=input_ids.device,
                dtype=hidden_states.dtype,
            ).masked_fill(causal_mask, float("-inf"))

        # Apply transformer layers
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss

    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        logits[i, token_id] /= repetition_penalty

            # Apply top-k sampling
            if top_k is not None:
                indices_to_remove = (
                    logits
                    < torch.topk(logits, top_k)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(
                    logits, dtype=torch.bool
                )
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        return tensor_to_data(tensor_data)


# import torch
# from typing import Optional


class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"


class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(
            byte_indices.cpu().numpy().astype(np.uint8).tolist()
        )

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Try with error handling
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If that failed, assume raw pixel data
            if not size:
                # Try to determine size from byte count
                pixel_count = len(byte_sequence) // len(mode)
                size = (
                    int(np.sqrt(pixel_count)),
                    int(np.sqrt(pixel_count)),
                )

            # Convert raw bytes to pixel array
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((*size, len(mode)))

            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine format string based on sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack bytes to samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to desired format.

        Args:
            model_output: Either tensor from model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError(
                "Video decoding not yet implemented"
            )
        else:  # BINARY
            return byte_sequence


# Usage example


class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None


class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers)
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"RIFF": "WEBP",
        # Audio
        b"RIFF....WAVE": "WAV",
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }

    def __init__(self):
        self.magic = magic.Magic(mime=True)

    def _check_text_probability(self, data: bytes) -> float:
        """Estimate probability that data is text."""
        if not data:
            return 0.0
        # Check if data is valid UTF-8
        try:
            data.decode("utf-8")
            # Count printable ASCII characters
            printable = sum(1 for b in data if 32 <= b <= 126)
            return printable / len(data)
        except UnicodeDecodeError:
            return 0.0

    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is a valid image and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                img = Image.open(bio)
                return True, {
                    "format": img.format,
                    "size": img.size,
                    "mode": img.mode,
                }
        except Exception:
            return False, {}

    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is valid audio and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                # Try to parse as WAV
                with wave.open(bio) as wav:
                    return True, {
                        "channels": wav.getnchannels(),
                        "sample_width": wav.getsampwidth(),
                        "framerate": wav.getframerate(),
                        "frames": wav.getnframes(),
                    }
        except Exception:
            # Check for other audio signatures
            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
                if data.startswith(sig):
                    return True, {"format": "compressed_audio"}
            return False, {}

    def _detect_boundaries(
        self, data: bytes
    ) -> List[Tuple[int, int, Modality]]:
        """Detect boundaries between different modalities in the data."""
        boundaries = []
        current_pos = 0

        while current_pos < len(data):
            # Look for known signatures
            for sig, format_type in self.SIGNATURES.items():
                if data[current_pos:].startswith(sig):
                    # Found a signature, determine its length
                    if format_type in ["JPEG", "PNG", "GIF"]:
                        # Find image end
                        try:
                            with io.BytesIO(
                                data[current_pos:]
                            ) as bio:
                                img = Image.open(bio)
                                img.verify()
                                size = bio.tell()
                                boundaries.append(
                                    (
                                        current_pos,
                                        current_pos + size,
                                        Modality.IMAGE,
                                    )
                                )
                                current_pos += size
                                continue
                        except Exception:
                            pass

            # Check for text sections
            text_prob = self._check_text_probability(
                data[current_pos : current_pos + 1024]
            )
            if text_prob > 0.8:
                # Look for end of text section
                end_pos = current_pos + 1
                while end_pos < len(data):
                    if (
                        self._check_text_probability(
                            data[end_pos : end_pos + 32]
                        )
                        < 0.5
                    ):
                        break
                    end_pos += 1
                boundaries.append(
                    (current_pos, end_pos, Modality.TEXT)
                )
                current_pos = end_pos
                continue

            current_pos += 1

        return boundaries

    def detect_modality(self, data: bytes) -> ModalityInfo:
        """Detect modality of byte sequence."""
        # First check for single modality
        mime_type = self.magic.from_buffer(data)

        # Check text
        text_prob = self._check_text_probability(data)
        if text_prob > 0.9:
            return ModalityInfo(
                modality=Modality.TEXT,
                confidence=text_prob,
                metadata={"mime_type": mime_type},
            )

        # Check image
        is_image, image_meta = self._check_image_validity(data)
        if is_image:
            return ModalityInfo(
                modality=Modality.IMAGE,
                confidence=1.0,
                metadata={**image_meta, "mime_type": mime_type},
            )

        # Check audio
        is_audio, audio_meta = self._check_audio_validity(data)
        if is_audio:
            return ModalityInfo(
                modality=Modality.AUDIO,
                confidence=1.0,
                metadata={**audio_meta, "mime_type": mime_type},
            )

        # Check for multimodal content
        boundaries = self._detect_boundaries(data)
        if len(boundaries) > 1:
            sub_modalities = []
            for start, end, modality in boundaries:
                chunk_data = data[start:end]
                sub_info = self.detect_modality(chunk_data)
                if sub_info.modality != Modality.BINARY:
                    sub_modalities.append(sub_info)

            if sub_modalities:
                return ModalityInfo(
                    modality=Modality.MULTIMODAL,
                    confidence=0.8,
                    metadata={"mime_type": "multipart/mixed"},
                    sub_modalities=sub_modalities,
                )

        # Default to binary
        return ModalityInfo(
            modality=Modality.BINARY,
            confidence=0.5,
            metadata={"mime_type": mime_type},
        )

    def split_modalities(
        self, data: bytes
    ) -> List[Tuple[Modality, bytes, Dict]]:
        """Split multimodal data into separate modalities."""
        boundaries = self._detect_boundaries(data)
        result = []

        for start, end, modality in boundaries:
            chunk = data[start:end]
            info = self.detect_modality(chunk)
            result.append((modality, chunk, info.metadata))

        return result


class AutoDetectBytesDecoder:
    """Decoder that automatically detects and decodes different modalities."""

    def __init__(self):
        self.detector = ModalityDetector()
        self.text_decoder = ByteDetokenizer()  # From previous example

    def decode(
        self, data: bytes
    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
        """Automatically detect and decode byte sequence."""
        info = self.detector.detect_modality(data)

        if info.modality == Modality.MULTIMODAL:
            # Handle multimodal content
            parts = self.detector.split_modalities(data)
            return [
                self.decode(chunk) for modality, chunk, _ in parts
            ]

        if info.modality == Modality.TEXT:
            return self.text_decoder.decode_text(data)
        elif info.modality == Modality.IMAGE:
            return self.text_decoder.decode_image(data)
        elif info.modality == Modality.AUDIO:
            return self.text_decoder.decode_audio(data)
        else:
            return data


# # Example usage
# def demo_auto_detection():
#     """Demonstrate auto modality detection."""
#     # Create mixed content
#     text = "Hello, World!".encode('utf-8')

#     # Create a small test image
#     img = Image.new('RGB', (100, 100), color='red')
#     img_bytes = io.BytesIO()
#     img.save(img_bytes, format='PNG')

#     # Combine into multimodal content
#     mixed_content = text + img_bytes.getvalue()

#     # Initialize decoder
#     decoder = AutoDetectBytesDecoder()

#     # Decode
#     result = decoder.decode(mixed_content)

#     if isinstance(result, list):
#         print("Detected multimodal content:")
#         for i, part in enumerate(result):
#             print(f"Part {i+1}: {type(part)}")

# if __name__ == "__main__":
#     demo_auto_detection()


def tensor_to_data(tensor: Tensor):
    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)

    # Initialize auto-detector
    decoder = AutoDetectBytesDecoder()

    # Decode with automatic detection
    result = decoder.decode(byte_sequence)

    return result


def demo_byte_predictor():
    """Demo with smaller dimensions to test."""
    # Initialize model configuration with adjusted dimensions
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,  # Smaller for testing
        num_layers=2,  # Fewer layers for testing
        num_key_value_heads=2,
        num_query_heads=4,
        dropout=0.1,
        max_sequence_length=1024,
    )

    # Initialize model
    model = EnhancedBytePredictor(config)
    logger.info("Model initialized")

    # Move to GPU if available
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    model = model.to(device)
    logger.info(f"Using device: {device}")

    # Create sample input data
    batch_size = 2
    seq_length = 16  # Shorter sequence for testing
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_length), device=device
    )
    logger.info(f"Created input tensor of shape: {input_ids.shape}")

    # Test forward pass
    try:
        logits = model(input_ids)
        logger.info(
            f"Forward pass successful! Output shape: {logits.shape}"
        )

        # Test loss computation
        target_ids = torch.randint(
            0,
            config.vocab_size,
            (batch_size, seq_length),
            device=device,
        )
        loss = model.compute_loss(input_ids, target_ids)
        logger.info(
            f"Loss computation successful! Loss value: {loss.item():.4f}"
        )

        # Test generation; use _generate, which returns the raw token
        # tensor (generate() decodes the bytes into data and the result
        # would have no .shape to log)
        prompt = torch.randint(
            0,
            config.vocab_size,
            (1, 4),  # Very short prompt for testing
            device=device,
        )
        generated = model._generate(
            prompt, max_new_tokens=8, temperature=0.8, top_k=50
        )
        logger.info(
            f"Generation successful! Generated shape: {generated.shape}"
        )

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        raise


if __name__ == "__main__":
    # Set up logging
    # logger.remove()  # Remove default handler
    # logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")

    demo_byte_predictor()
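
An aside on the `repeat_interleave` step in `MultiQueryAttention.forward` above: it is the core of multi-query/grouped-query attention, duplicating each of the few key/value heads until the shapes line up with the larger number of query heads. A minimal shape-only sketch using the config's default 32 query / 8 KV heads, with heads on dim 1 for simplicity (the model above expands on dim 2 because its tensors are in a seq-first layout at that point):

    import torch

    q = torch.randn(2, 32, 16, 64)  # [batch, query_heads, seq, head_dim]
    k = torch.randn(2, 8, 16, 64)   # [batch, kv_heads, seq, head_dim]
    # Each KV head serves 32 // 8 = 4 query heads after expansion
    k = k.repeat_interleave(32 // 8, dim=1)
    assert k.shape == q.shape  # q @ k.transpose(-2, -1) now works head-for-head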
@@ -0,0 +1,8 @@
from swarms import Agent

Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    streaming_on=True,
).run("What are 5 hft algorithms")
@@ -0,0 +1,618 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class TransformerConfig:
    """Configuration class for MoE Transformer model parameters."""

    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_expert_layers: int = 4
    num_experts: int = 8
    expert_capacity: int = 32
    max_position_embeddings: int = 1024
    dropout_prob: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    num_query_groups: int = 4  # For multi-query attention


class ExpertLayer(nn.Module):
    """Individual expert neural network."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.fc1 = nn.Linear(
            config.hidden_size, 4 * config.hidden_size
        )
        self.fc2 = nn.Linear(
            4 * config.hidden_size, config.hidden_size
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MixtureOfExperts(nn.Module):
    """Mixture of Experts layer with dynamic routing."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity

        # Create expert networks
        self.experts = nn.ModuleList(
            [ExpertLayer(config) for _ in range(config.num_experts)]
        )

        # Router network
        self.router = nn.Linear(
            config.hidden_size, config.num_experts
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Route inputs to experts and combine outputs."""
        batch_size, seq_len, hidden_size = x.shape

        # Calculate routing probabilities
        router_logits = self.router(x)
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k = 2
        gates, indices = torch.topk(routing_weights, top_k, dim=-1)
        gates = F.softmax(gates, dim=-1)

        # Process inputs through selected experts
        final_output = torch.zeros_like(x)
        router_load = torch.zeros(self.num_experts, device=x.device)

        for i in range(top_k):
            expert_index = indices[..., i]
            gate = gates[..., i : i + 1]

            # Count expert assignments
            for j in range(self.num_experts):
                router_load[j] += (expert_index == j).float().sum()

            # Process through selected experts
            for j in range(self.num_experts):
                mask = expert_index == j
                if not mask.any():
                    continue

                expert_input = x[mask]
                expert_output = self.experts[j](expert_input)
                final_output[mask] += gate[mask] * expert_output

        # Load-balancing auxiliary loss: squared coefficient of
        # variation of the per-expert assignment counts
        aux_loss = router_load.float().var() / (
            router_load.float().mean() ** 2
        )

        return final_output, {"load_balancing_loss": aux_loss}


class MultiQueryAttention(nn.Module):
    """Multi-Query Attention mechanism with proper multi-query group handling."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.num_query_groups = config.num_query_groups
        self.hidden_size = config.hidden_size
        self.head_dim = (
            config.hidden_size // config.num_attention_heads
        )

        # Query projection maintains full head dimension
        self.q_proj = nn.Linear(
            config.hidden_size, config.hidden_size
        )

        # Key and value projections use reduced number of heads (query groups)
        self.k_proj = nn.Linear(
            config.hidden_size,
            self.head_dim * config.num_query_groups,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            self.head_dim * config.num_query_groups,
        )

        self.dropout = nn.Dropout(config.dropout_prob)

        # Calculate heads per group for proper reshaping
        self.heads_per_group = (
            self.num_attention_heads // self.num_query_groups
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Reshape queries to full number of heads
        queries = queries.view(
            batch_size,
            seq_length,
            self.num_attention_heads,
            self.head_dim,
        )

        # Reshape keys and values to number of query groups
        keys = keys.view(
            batch_size,
            seq_length,
            self.num_query_groups,
            self.head_dim,
        )
        values = values.view(
            batch_size,
            seq_length,
            self.num_query_groups,
            self.head_dim,
        )

        # Transpose for batch matrix multiplication
        queries = queries.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)
        values = values.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)

        # Repeat keys and values for each head in the group
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # Expand attention mask to match scores dimensions
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(
                batch_size,
                self.num_attention_heads,
                seq_length,
                seq_length,
            )
            mask_value = torch.finfo(scores.dtype).min
            attention_mask = expanded_mask.eq(0).float() * mask_value
            scores = scores + attention_mask

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute attention output
        attention_output = torch.matmul(attention_weights, values)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(
            batch_size, seq_length, -1
        )

        return attention_output, None


class MoETransformer(nn.Module):
    """
    Production-grade Transformer model with Mixture of Experts and Multi-Query Attention.

    Features:
    - Multi-Query Attention mechanism for efficient inference
    - Mixture of Experts for dynamic routing and specialization
    - Real-time weight updates based on input similarity
    - Built-in logging and monitoring
    - Type annotations for better code maintainability
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Initialize components
        self.embedding = nn.Embedding(
            config.vocab_size, config.hidden_size
        )
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Multi-Query Attention layers
        self.attention_layers = nn.ModuleList(
            [
                MultiQueryAttention(config)
                for _ in range(config.num_expert_layers)
            ]
        )

        # Mixture of Experts layers
        self.moe_layers = nn.ModuleList(
            [
                MixtureOfExperts(config)
                for _ in range(config.num_expert_layers)
            ]
        )

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_prob)

        # Output projection
        self.output_projection = nn.Linear(
            config.hidden_size, config.vocab_size
        )

        # Initialize weights
        self.apply(self._init_weights)
        logger.info("Initialized MoETransformer model")

    def _init_weights(self, module: nn.Module):
        """Initialize model weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            if (
                isinstance(module, nn.Linear)
                and module.bias is not None
            ):
                module.bias.data.zero_()

    def get_position_embeddings(self, position_ids: Tensor) -> Tensor:
        """Generate position embeddings."""
        return self.position_embedding(position_ids)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Dict]:
        """
        Forward pass through the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask for padding
            position_ids: Position IDs for positional encoding
            cache: Cache for key/value states in generation

        Returns:
            tuple: (logits, auxiliary_outputs)
        """
        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(
                input_ids
            )

        # Get embeddings
        inputs_embeds = self.embedding(input_ids)
        position_embeds = self.get_position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Initialize auxiliary outputs
        aux_outputs = {"moe_losses": []}

        # Process through transformer layers
        for attention_layer, moe_layer in zip(
            self.attention_layers, self.moe_layers
        ):
            # Multi-Query Attention
            attention_output, _ = attention_layer(
                hidden_states, attention_mask, cache
            )
            hidden_states = self.layer_norm(
                hidden_states + attention_output
            )

            # Mixture of Experts
            moe_output, moe_aux = moe_layer(hidden_states)
            hidden_states = self.layer_norm(
                hidden_states + moe_output
            )
            aux_outputs["moe_losses"].append(
                moe_aux["load_balancing_loss"]
            )

        # Final output projection
        logits = self.output_projection(hidden_states)

        return logits, aux_outputs

    def fetch_loss(
        self,
        logits: Tensor,
        labels: Tensor,
        aux_outputs: Dict,
        reduction: str = "mean",
    ) -> Tensor:
        """
        Calculate the total loss including MoE balancing losses.

        Args:
            logits: Model output logits
            labels: Ground truth labels
            aux_outputs: Auxiliary outputs from forward pass
            reduction: Loss reduction method

        Returns:
            Tensor: Total loss
        """
        # Calculate cross entropy loss
        ce_loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            labels.view(-1),
            reduction=reduction,
        )

        # Calculate MoE loss
        moe_loss = torch.stack(aux_outputs["moe_losses"]).mean()

        # Combine losses
        total_loss = ce_loss + 0.01 * moe_loss

        logger.debug(
            f"CE Loss: {ce_loss.item():.4f}, "
            f"MoE Loss: {moe_loss.item():.4f}"
        )

        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> Tensor:
        """
        Generate text using the model.

        Args:
            input_ids: Initial input tokens
            max_length: Maximum sequence length to generate
            temperature: Sampling temperature
            top_k: Number of highest probability tokens to keep
            top_p: Cumulative probability for nucleus sampling

        Returns:
            Tensor: Generated token IDs
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize sequence with input_ids
        generated = input_ids

        # Cache for key-value pairs
        cache = {}

        for _ in range(max_length):
            # Get position IDs for current sequence
            position_ids = torch.arange(
                generated.shape[1], dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).expand(
                batch_size, -1
            )

            # Forward pass
            logits, _ = self.forward(
                generated, position_ids=position_ids, cache=cache
            )

            # Get next token logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                indices_to_remove = (
                    next_token_logits
                    < torch.topk(next_token_logits, top_k)[0][
                        ..., -1, None
                    ]
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                # Scatter the removal mask back to vocabulary order.
                # (The original indexed next_token_logits directly with
                # the raw token ids, which masked the wrong entries.)
                indices_to_remove = torch.zeros_like(
                    next_token_logits, dtype=torch.bool
                ).scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to sequence
            generated = torch.cat((generated, next_token), dim=1)

            # Check for end of sequence token
            if (next_token == self.config.vocab_size - 1).all():
                break

        return generated


# Initialize model configuration
config = TransformerConfig(
    vocab_size=50257,
    hidden_size=768,
    num_attention_heads=12,
    num_expert_layers=4,
    num_experts=8,
    expert_capacity=32,
    max_position_embeddings=1024,
    num_query_groups=4,
)


def prepare_sample_data(
    batch_size: int = 8,
    seq_length: int = 512,
    vocab_size: int = 50257,
) -> DataLoader:
    """Create sample data for demonstration."""
    # Create random input sequences
    input_ids = torch.randint(
        0, vocab_size, (100, seq_length)  # 100 samples
    )

    # Create target sequences (random here; real targets would be the
    # inputs shifted by one position)
    labels = torch.randint(0, vocab_size, (100, seq_length))

    # Create attention masks (1 for real tokens, 0 for padding)
    attention_mask = torch.ones_like(input_ids)

    # Create dataset and dataloader
    dataset = TensorDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )

    return dataloader


def train_step(
    model: MoETransformer,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
    """Execute single training step."""
    model.train()
    optimizer.zero_grad()

    # Unpack batch
    input_ids, attention_mask, labels = [b.to(device) for b in batch]

    # Forward pass
    logits, aux_outputs = model(
        input_ids=input_ids, attention_mask=attention_mask
    )

    # Calculate loss
    loss = model.fetch_loss(logits, labels, aux_outputs)

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()


def main():
    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Initialize model
    model = MoETransformer(config).to(device)
    logger.info("Model initialized")

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=0.01
    )

    # Prepare data
    dataloader = prepare_sample_data()
    logger.info("Data prepared")

    # Training loop
    num_epochs = 3
    for epoch in range(num_epochs):
        epoch_losses = []

        for batch_idx, batch in enumerate(dataloader):
            loss = train_step(model, batch, optimizer, device)
            epoch_losses.append(loss)

            if batch_idx % 10 == 0:
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs} "
                    f"Batch {batch_idx}/{len(dataloader)} "
                    f"Loss: {loss:.4f}"
                )

        avg_loss = np.mean(epoch_losses)
        logger.info(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    # Generation example
    model.eval()
    with torch.no_grad():
        # Prepare input prompt
        prompt = torch.randint(0, config.vocab_size, (1, 10)).to(
            device
        )

        # Generate sequence
        generated = model.generate(
            input_ids=prompt,
            max_length=50,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

        logger.info(f"Generated sequence shape: {generated.shape}")


if __name__ == "__main__":
    main()
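
One remark on the routing in `MixtureOfExperts.forward` above: after `topk` selects the two largest routing probabilities, the code renormalizes them with a second softmax rather than dividing by their sum, which pulls the two gate values closer together. A tiny sketch of the difference:

    import torch
    import torch.nn.functional as F

    routing_weights = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # post-softmax router output
    gates, indices = torch.topk(routing_weights, 2, dim=-1)   # selects experts 0 and 1
    print(F.softmax(gates, dim=-1))              # ~[0.55, 0.45], as in the code above
    print(gates / gates.sum(-1, keepdim=True))   # sum-normalized alternative: [0.625, 0.375]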
@ -1,121 +0,0 @@
import json
import os
from difflib import SequenceMatcher

import requests
from dotenv import load_dotenv
from loguru import logger
from supabase import Client, create_client

load_dotenv()

# Initialize Supabase client
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Swarms API URL and headers
SWARMS_API_URL = "https://swarms.world/api/add-prompt"
SWARMS_API_KEY = os.getenv("SWARMS_API_KEY")
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {SWARMS_API_KEY}",
}

# Configure logger
logger.add(
    "fetch_and_publish_prompts.log", rotation="1 MB"
)  # Log file with rotation


def fetch_and_publish_prompts():
    logger.info("Starting to fetch and publish prompts.")

    # Fetch data from Supabase
    try:
        response = (
            supabase.table("swarms_framework_schema")
            .select("*")
            .execute()
        )
        rows = response.data
        logger.info(f"Fetched {len(rows)} rows from Supabase.")
    except Exception as e:
        logger.error(f"Failed to fetch data from Supabase: {e}")
        return

    # Track published prompts to avoid duplicates
    published_prompts = set()

    for row in rows:
        # Extract agent_name and system_prompt
        data = row.get("data", {})
        agent_name = data.get("agent_name")
        system_prompt = data.get("system_prompt")

        # Skip if either is missing or duplicate
        if not agent_name or not system_prompt:
            logger.warning(
                f"Skipping row due to missing agent_name or system_prompt: {row}"
            )
            continue
        if is_duplicate(system_prompt, published_prompts):
            logger.info(
                f"Skipping duplicate prompt for agent: {agent_name}"
            )
            continue

        # Create the data payload for the marketplace
        prompt_data = {
            "name": f"{agent_name} - System Prompt",
            "prompt": system_prompt,
            "description": f"System prompt for agent {agent_name}.",
            "useCases": extract_use_cases(system_prompt),
            "tags": "agent, system-prompt",
        }

        # Publish to the marketplace
        try:
            response = requests.post(
                SWARMS_API_URL,
                headers=headers,
                data=json.dumps(prompt_data),
            )
            if response.status_code == 200:
                logger.info(
                    f"Successfully published prompt for agent: {agent_name}"
                )
                published_prompts.add(system_prompt)
            else:
                logger.error(
                    f"Failed to publish prompt for agent: {agent_name}. Response: {response.text}"
                )
        except Exception as e:
            logger.error(
                f"Exception occurred while publishing prompt for agent: {agent_name}. Error: {e}"
            )


def is_duplicate(new_prompt, published_prompts):
    """Check if the prompt is a near-duplicate using string similarity (SequenceMatcher)."""
    for prompt in published_prompts:
        similarity = SequenceMatcher(None, new_prompt, prompt).ratio()
        if (
            similarity > 0.9
        ):  # Threshold for considering prompts as duplicates
            return True
    return False


def extract_use_cases(prompt):
    """Extract use cases from the prompt by chunking it into meaningful segments."""
    # This is a simple placeholder; you can use a more advanced method to extract use cases
    chunks = [prompt[i : i + 50] for i in range(0, len(prompt), 50)]
    return [
        {"title": f"Use case {idx+1}", "description": chunk}
        for idx, chunk in enumerate(chunks)
    ]


# Main execution
fetch_and_publish_prompts()
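For a sense of how the 0.9 cutoff in is_duplicate behaves, here is a short sketch; the two prompt strings are made up for illustration:

from difflib import SequenceMatcher

a = "You are a helpful assistant for financial planning."
b = "You are a helpful assistant for financial analysis."
ratio = SequenceMatcher(None, a, b).ratio()
print(f"similarity: {ratio:.2f}")  # near-identical prompts score high; anything above 0.9 is treated as a duplicate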
@ -0,0 +1,433 @@
import asyncio
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Optional, Set

import aiohttp
import matplotlib.pyplot as plt
import networkx as nx
import websockets
from loguru import logger

from swarms import Agent

TREND_AGENT_PROMPT = """You are a specialized blockchain trend analysis agent. Your role:
1. Analyze transaction patterns in Solana blockchain data
2. Identify volume trends, price movements, and temporal patterns
3. Focus on whale movements and their market impact
4. Format findings in clear, structured JSON
5. Include confidence scores for each insight
6. Flag unusual patterns or anomalies
7. Provide historical context for significant movements

Output format:
{
    "trends": [
        {"pattern": str, "confidence": float, "impact": str}
    ],
    "whale_activity": {...},
    "temporal_analysis": {...}
}"""

RISK_AGENT_PROMPT = """You are a blockchain risk assessment specialist. Your tasks:
1. Identify suspicious transaction patterns
2. Monitor for known exploit signatures
3. Assess wallet clustering and relationship patterns
4. Evaluate transaction velocity and size anomalies
5. Check for bridge-related risks
6. Monitor smart contract interactions
7. Flag potential wash trading

Output format:
{
    "risk_score": float,
    "flags": [...],
    "recommendations": [...]
}"""

SUMMARY_AGENT_PROMPT = """You are a blockchain data synthesis expert. Your responsibilities:
1. Combine insights from trend and risk analyses
2. Prioritize actionable intelligence
3. Highlight critical patterns
4. Generate executive summaries
5. Provide market context
6. Make predictions with confidence intervals
7. Suggest trading strategies based on data

Output format:
{
    "key_insights": [...],
    "market_impact": str,
    "recommendations": {...}
}"""


@dataclass
class Transaction:
    signature: str
    timestamp: datetime
    amount: float
    from_address: str
    to_address: str


class SolanaRPC:
    def __init__(
        self, endpoint="https://api.mainnet-beta.solana.com"
    ):
        self.endpoint = endpoint
        self.session = None

    async def get_signatures(self, address: str) -> List[Dict]:
        if not self.session:
            self.session = aiohttp.ClientSession()

        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "getSignaturesForAddress",
            "params": [address, {"limit": 100}],
        }

        async with self.session.post(
            self.endpoint, json=payload
        ) as response:
            result = await response.json()
            return result.get("result", [])

    async def get_transaction(self, signature: str) -> Dict:
        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "getTransaction",
            "params": [
                signature,
                {
                    "encoding": "json",
                    "maxSupportedTransactionVersion": 0,
                },
            ],
        }

        async with self.session.post(
            self.endpoint, json=payload
        ) as response:
            result = await response.json()
            return result.get("result", {})


class AlertSystem:
    def __init__(self, email: str, threshold: float = 1000.0):
        self.email = email
        self.threshold = threshold
        self.smtp_server = "smtp.gmail.com"
        self.smtp_port = 587

    async def check_and_alert(
        self, transaction: Transaction, risk_score: float
    ):
        if transaction.amount > self.threshold or risk_score > 0.8:
            await self.send_alert(transaction, risk_score)

    async def send_alert(
        self, transaction: Transaction, risk_score: float
    ):
        # msg = MIMEText(
        #     f"High-risk transaction detected:\n"
        #     f"Amount: {transaction.amount} SOL\n"
        #     f"Risk Score: {risk_score}\n"
        #     f"Signature: {transaction.signature}"
        # )
        logger.info(
            f"Alert sent for transaction {transaction.signature}"
        )


class WalletClusterAnalyzer:
    def __init__(self):
        self.graph = nx.Graph()
        self.known_wallets: Set[str] = set()

    def update_graph(self, transaction: Transaction):
        self.graph.add_edge(
            transaction.from_address,
            transaction.to_address,
            weight=transaction.amount,
        )
        self.known_wallets.add(transaction.from_address)
        self.known_wallets.add(transaction.to_address)

    def identify_clusters(self) -> Dict:
        communities = nx.community.greedy_modularity_communities(
            self.graph
        )
        return {
            "clusters": [list(c) for c in communities],
            "central_wallets": [
                wallet
                for wallet in self.known_wallets
                if self.graph.degree[wallet] > 5
            ],
        }


class TransactionVisualizer:
    def __init__(self):
        self.transaction_history = []

    def add_transaction(self, transaction: Transaction):
        self.transaction_history.append(asdict(transaction))

    def generate_volume_chart(self) -> str:
        volumes = [tx["amount"] for tx in self.transaction_history]
        plt.figure(figsize=(12, 6))
        plt.plot(volumes)
        plt.title("Transaction Volume Over Time")
        plt.savefig("volume_chart.png")
        return "volume_chart.png"

    def generate_network_graph(
        self, wallet_analyzer: WalletClusterAnalyzer
    ) -> str:
        plt.figure(figsize=(15, 15))
        pos = nx.spring_layout(wallet_analyzer.graph)
        nx.draw(
            wallet_analyzer.graph,
            pos,
            node_size=1000,
            node_color="lightblue",
            with_labels=True,
        )
        plt.savefig("network_graph.png")
        return "network_graph.png"


class SolanaMultiAgentAnalyzer:
    def __init__(
        self,
        min_amount: float = 50.0,
        websocket_url: str = "wss://api.mainnet-beta.solana.com",
        alert_email: str = None,
    ):
        self.rpc = SolanaRPC()
        self.websocket_url = websocket_url
        self.min_amount = min_amount
        self.transactions = []

        self.wallet_analyzer = WalletClusterAnalyzer()
        self.visualizer = TransactionVisualizer()
        self.alert_system = (
            AlertSystem(alert_email) if alert_email else None
        )

        self.trend_agent = Agent(
            agent_name="trend-analyzer",
            system_prompt=TREND_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.risk_agent = Agent(
            agent_name="risk-analyzer",
            system_prompt=RISK_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.summary_agent = Agent(
            agent_name="summary-agent",
            system_prompt=SUMMARY_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        logger.add(
            "solana_analysis.log", rotation="500 MB", level="INFO"
        )

    async def start_websocket_stream(self):
        async with websockets.connect(
            self.websocket_url
        ) as websocket:
            subscribe_message = {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "programSubscribe",
                "params": [
                    "11111111111111111111111111111111",
                    {"encoding": "json", "commitment": "confirmed"},
                ],
            }
            await websocket.send(json.dumps(subscribe_message))

            while True:
                try:
                    msg = await websocket.recv()
                    transaction = await self.parse_websocket_message(
                        msg
                    )
                    if (
                        transaction
                        and transaction.amount >= self.min_amount
                    ):
                        await self.process_transaction(transaction)
                except Exception as e:
                    logger.error(f"Websocket error: {e}")
                    await asyncio.sleep(5)

    async def parse_websocket_message(
        self, msg: str
    ) -> Optional[Transaction]:
        try:
            data = json.loads(msg)
            if "params" in data and "result" in data["params"]:
                tx_data = data["params"]["result"]
                return Transaction(
                    signature=tx_data["signature"],
                    timestamp=datetime.fromtimestamp(
                        tx_data["blockTime"]
                    ),
                    amount=float(
                        tx_data["meta"]["postBalances"][0]
                        - tx_data["meta"]["preBalances"][0]
                    )
                    / 1e9,
                    from_address=tx_data["transaction"]["message"][
                        "accountKeys"
                    ][0],
                    to_address=tx_data["transaction"]["message"][
                        "accountKeys"
                    ][1],
                )
        except Exception as e:
            logger.error(f"Error parsing websocket message: {e}")
        return None

    async def process_transaction(self, transaction: Transaction):
        self.wallet_analyzer.update_graph(transaction)
        self.visualizer.add_transaction(transaction)

        risk_analysis = await self.risk_agent.run(
            f"Analyze risk for transaction: {json.dumps(asdict(transaction))}"
        )

        if self.alert_system:
            await self.alert_system.check_and_alert(
                transaction, risk_analysis.get("risk_score", 0)
            )

    async def fetch_transactions(self) -> List[Transaction]:
        try:
            signatures = await self.rpc.get_signatures(
                "11111111111111111111111111111111"
            )
            transactions = []

            for sig_info in signatures:
                tx_data = await self.rpc.get_transaction(
                    sig_info["signature"]
                )
                if not tx_data or "meta" not in tx_data:
                    continue

                pre_balances = tx_data["meta"]["preBalances"]
                post_balances = tx_data["meta"]["postBalances"]
                amount = abs(pre_balances[0] - post_balances[0]) / 1e9

                if amount >= self.min_amount:
                    tx = Transaction(
                        signature=sig_info["signature"],
                        timestamp=datetime.fromtimestamp(
                            tx_data["blockTime"]
                        ),
                        amount=amount,
                        from_address=tx_data["transaction"][
                            "message"
                        ]["accountKeys"][0],
                        to_address=tx_data["transaction"]["message"][
                            "accountKeys"
                        ][1],
                    )
                    transactions.append(tx)

            return transactions
        except Exception as e:
            logger.error(f"Error fetching transactions: {e}")
            return []

    async def analyze_transactions(
        self, transactions: List[Transaction]
    ) -> Dict:
        tx_data = [asdict(tx) for tx in transactions]
        cluster_data = self.wallet_analyzer.identify_clusters()

        trend_analysis = await self.trend_agent.run(
            f"Analyze trends in: {json.dumps(tx_data)}"
        )
        print(trend_analysis)

        risk_analysis = await self.risk_agent.run(
            f"Analyze risks in: {json.dumps({'transactions': tx_data, 'clusters': cluster_data})}"
        )
        print(risk_analysis)

        summary = await self.summary_agent.run(
            f"Synthesize insights from: {trend_analysis}, {risk_analysis}"
        )

        print(summary)

        volume_chart = self.visualizer.generate_volume_chart()
        network_graph = self.visualizer.generate_network_graph(
            self.wallet_analyzer
        )

        return {
            "transactions": tx_data,
            "trend_analysis": trend_analysis,
            "risk_analysis": risk_analysis,
            "cluster_analysis": cluster_data,
            "summary": summary,
            "visualizations": {
                "volume_chart": volume_chart,
                "network_graph": network_graph,
            },
        }

    async def run_continuous_analysis(self):
        logger.info("Starting continuous analysis")
        asyncio.create_task(self.start_websocket_stream())

        while True:
            try:
                transactions = await self.fetch_transactions()
                if transactions:
                    analysis = await self.analyze_transactions(
                        transactions
                    )
                    timestamp = datetime.now().strftime(
                        "%Y%m%d_%H%M%S"
                    )
                    with open(f"analysis_{timestamp}.json", "w") as f:
                        json.dump(analysis, f, indent=2, default=str)
                    logger.info(
                        f"Analysis completed: analysis_{timestamp}.json"
                    )
                await asyncio.sleep(60)
            except Exception as e:
                logger.error(f"Error in analysis loop: {e}")
                await asyncio.sleep(60)


# Add to __main__:
if __name__ == "__main__":
    logger.info("Starting Solana analyzer...")
    analyzer = SolanaMultiAgentAnalyzer(alert_email="your@email.com")
    try:
        asyncio.run(analyzer.run_continuous_analysis())
    except Exception as e:
        logger.error(f"Critical error: {e}")
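Both parse_websocket_message and fetch_transactions derive a transfer size from the fee payer's balance delta and divide by 1e9 to convert lamports to SOL. A minimal illustration with hypothetical balances:

# 1 SOL = 1_000_000_000 lamports; account 0 is the fee payer, so its
# delta includes the transaction fee as well as the transferred amount.
pre_balances = [5_000_000_000, 1_000_000_000]   # hypothetical lamport balances
post_balances = [3_990_000_000, 2_000_000_000]
amount_sol = abs(pre_balances[0] - post_balances[0]) / 1e9
print(amount_sol)  # 1.01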
@ -1,29 +0,0 @@
from typing import Dict, Optional
from datetime import datetime
from pydantic import BaseModel, Field


class Message(BaseModel):
    """
    Represents a message with timestamp and optional metadata.

    Usage
    --------------
    mes = Message(
        sender = "Kye",
        content = "message"
    )

    print(mes)
    """

    timestamp: datetime = Field(default_factory=datetime.now)
    sender: str
    content: str
    metadata: Optional[Dict[str, str]] = {}

    def __repr__(self) -> str:
        """
        Return a human-readable representation of the message.
        """
        return f"{self.timestamp} - {self.sender}: {self.content}"
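A short usage sketch for the model above; the field values are illustrative:

msg = Message(
    sender="Kye",
    content="Hello swarm",
    metadata={"channel": "general"},
)
print(repr(msg))  # "<timestamp> - Kye: Hello swarm"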
@ -0,0 +1,15 @@
from typing import Literal

# Literal of output types
OutputType = Literal[
    "all",
    "final",
    "list",
    "dict",
    ".json",
    ".md",
    ".txt",
    ".yaml",
    ".toml",
]
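Literal aliases like OutputType constrain values for static checkers only; a runtime guard needs typing.get_args. A hedged sketch, where format_output and its behavior are hypothetical:

from typing import get_args


def format_output(history: list, output_type: OutputType = "final"):
    # Runtime guard: Literal is only enforced by static type checkers
    if output_type not in get_args(OutputType):
        raise ValueError(f"Unsupported output type: {output_type}")
    return history if output_type == "all" else history[-1]


print(format_output(["step 1", "step 2"]))  # "step 2"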
@ -1,511 +0,0 @@
"""
Todo
- [ ] Test the new api feature
- [ ] Add the agent schema for every agent -- following OpenAI assistants schema
- [ ] then add the swarm schema for the swarm url: /v1/swarms/{swarm_name}/agents/{agent_id}
- [ ] Add the agent schema for the agent url: /v1/swarms/{swarm_name}/agents/{agent_id}
"""

import asyncio
import multiprocessing
import queue
import threading
from typing import List, Optional

import tenacity

# from fastapi import FastAPI
from pydantic import BaseModel

from swarms.structs.agent import Agent
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.loguru_logger import initialize_logger

logger = initialize_logger("swarm-network")


# Pydantic models
class TaskRequest(BaseModel):
    task: str


# Pydantic models
class TaskResponse(BaseModel):
    result: str


class AgentInfo(BaseModel):
    agent_name: str
    agent_description: str


class SwarmInfo(BaseModel):
    swarm_name: str
    swarm_description: str
    agents: List[AgentInfo]


# Helper function to get the number of workers
def get_number_of_workers():
    return multiprocessing.cpu_count()


# [TODO] Add the agent schema for every agent -- following OpenAI assistants schema
class SwarmNetwork(BaseSwarm):
    """
    SwarmNetwork class

    The SwarmNetwork class is responsible for managing the agents pool
    and the task queue. It also monitors the health of the agents and
    scales the pool up or down based on the number of pending tasks
    and the current load of the agents.

    For example, if the number of pending tasks is greater than the
    number of agents in the pool, the SwarmNetwork will scale up the
    pool by adding new agents. If the number of pending tasks is less
    than the number of agents in the pool, the SwarmNetwork will scale
    down the pool by removing agents.

    The SwarmNetwork class also provides a simple API for interacting
    with the agents pool. The API is implemented using the FastAPI
    framework and is enabled by default. The API can be disabled by
    setting the `api_enabled` parameter to False.

    Features:
        - Agent pool management
        - Task queue management
        - Agent health monitoring
        - Agent pool scaling
        - Simple API for interacting with the agent pool
        - Simple API for interacting with the task queue
        - Simple API for interacting with the agent health monitor
        - Simple API for interacting with the agent pool scaler
        - Create APIs for each agent in the pool (optional)
        - Run each agent on its own thread
        - Run each agent on its own process
        - Run each agent on its own container
        - Run each agent on its own machine
        - Run each agent on its own cluster

    Attributes:
        task_queue (queue.Queue): A queue for storing tasks.
        idle_threshold (float): The idle threshold for the agents.
        busy_threshold (float): The busy threshold for the agents.
        agents (List[Agent]): A list of agents in the pool.
        api_enabled (bool): A flag to enable/disable the API.
        logging_enabled (bool): A flag to enable/disable logging.

    Example:
        >>> from swarms.structs.agent import Agent
        >>> from swarms.structs.swarm_net import SwarmNetwork
        >>> agent = Agent()
        >>> swarm = SwarmNetwork(agents=[agent])
        >>> swarm.add_task("task")
        >>> swarm.run()

    """

    def __init__(
        self,
        name: str = None,
        description: str = None,
        agents: List[Agent] = None,
        idle_threshold: float = 0.2,
        busy_threshold: float = 0.7,
        api_enabled: Optional[bool] = False,
        logging_enabled: Optional[bool] = False,
        api_on: Optional[bool] = False,
        host: str = "0.0.0.0",
        port: int = 8000,
        swarm_callable: Optional[callable] = None,
        *args,
        **kwargs,
    ):
        super().__init__(agents=agents, *args, **kwargs)
        self.name = name
        self.description = description
        self.agents = agents
        self.task_queue = queue.Queue()
        self.idle_threshold = idle_threshold
        self.busy_threshold = busy_threshold
        self.lock = threading.Lock()
        self.api_enabled = api_enabled
        self.logging_enabled = logging_enabled
        self.host = host
        self.port = port
        self.swarm_callable = swarm_callable

        # Ensure that the agents list is not empty
        if not agents:
            raise ValueError("The agents list cannot be empty")

        # Create a dictionary of agents for easy access
        self.agent_dict = {agent.id: agent for agent in agents}

        # # Create the FastAPI instance
        # if api_on is True:
        #     logger.info("Creating FastAPI instance")
        #     self.app = FastAPI(debug=True, *args, **kwargs)

        #     self.app.add_middleware(
        #         CORSMiddleware,
        #         allow_origins=["*"],
        #         allow_credentials=True,
        #         allow_methods=["*"],
        #         allow_headers=["*"],
        #     )

        #     logger.info("Routes set for creation")
        #     self._create_routes()

    def add_task(self, task):
        """Add task to the task queue

        Args:
            task (_type_): _description_

        Example:
        >>> from swarms.structs.agent import Agent
        >>> from swarms.structs.swarm_net import SwarmNetwork
        >>> agent = Agent()
        >>> swarm = SwarmNetwork(agents=[agent])
        >>> swarm.add_task("task")
        """
        self.logger.info(f"Adding task {task} to queue")
        try:
            self.task_queue.put(task)
            self.logger.info(f"Task {task} added to queue")
        except Exception as error:
            print(
                f"Error adding task to queue: {error} try again with"
                " a new task"
            )
            raise error

    async def async_add_task(self, task):
        """Add task to the task queue

        Args:
            task (_type_): _description_

        Example:
        >>> from swarms.structs.agent import Agent
        >>> from swarms.structs.swarm_net import SwarmNetwork
        >>> agent = Agent()
        >>> swarm = SwarmNetwork(agents=[agent])
        >>> swarm.add_task("task")

        """
        self.logger.info(
            f"Adding task {task} to queue asynchronously"
        )
        try:
            # Add task to queue asynchronously with asyncio
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, self.task_queue.put, task
            )
            self.logger.info(f"Task {task} added to queue")
        except Exception as error:
            print(
                f"Error adding task to queue: {error} try again with"
                " a new task"
            )
            raise error

    # def _create_routes(self) -> None:
    #     """
    #     Creates the routes for the API.
    #     """
    #     # Extensive logging
    #     logger.info("Creating routes for the API")

    #     # Routes available
    #     logger.info(
    #         "Routes available: /v1/swarms, /v1/health, /v1/swarms/{swarm_name}/agents/{agent_id}, /v1/swarms/{swarm_name}/run"
    #     )

    #     @self.app.get("/v1/swarms", response_model=SwarmInfo)
    #     async def get_swarms() -> SwarmInfo:
    #         try:
    #             logger.info("Getting swarm information")
    #             return SwarmInfo(
    #                 swarm_name=self.swarm_name,
    #                 swarm_description=self.swarm_description,
    #                 agents=[
    #                     AgentInfo(
    #                         agent_name=agent.agent_name,
    #                         agent_description=agent.agent_description,
    #                     )
    #                     for agent in self.agents
    #                 ],
    #             )
    #         except Exception as e:
    #             logger.error(f"Error getting swarm information: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.get("/v1/health")
    #     async def get_health() -> Dict[str, str]:
    #         try:
    #             logger.info("Checking health status")
    #             return {"status": "healthy"}
    #         except Exception as e:
    #             logger.error(f"Error checking health status: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.get(f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}")
    #     async def get_agent_info(agent_id: str) -> AgentInfo:
    #         try:
    #             logger.info(f"Getting information for agent {agent_id}")
    #             agent = self.agent_dict.get(agent_id)
    #             if not agent:
    #                 raise HTTPException(
    #                     status_code=404, detail="Agent not found"
    #                 )
    #             return AgentInfo(
    #                 agent_name=agent.agent_name,
    #                 agent_description=agent.agent_description,
    #             )
    #         except Exception as e:
    #             logger.error(f"Error getting agent information: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.post(
    #         f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}/run",
    #         response_model=TaskResponse,
    #     )
    #     async def run_agent_task(
    #         task_request: TaskRequest,
    #     ) -> TaskResponse:
    #         try:
    #             logger.info("Running agent task")
    #             # Assuming only one agent in the swarm for this example
    #             agent = self.agents[0]
    #             logger.info(f"Running agent task: {task_request.task}")
    #             result = agent.run(task_request.task)
    #             return TaskResponse(result=result)
    #         except Exception as e:
    #             logger.error(f"Error running agent task: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    # def get_app(self) -> FastAPI:
    #     """
    #     Returns the FastAPI instance.

    #     Returns:
    #         FastAPI: The FastAPI instance.
    #     """
    #     return self.app

    def run_single_agent(
        self, agent_id, task: Optional[str], *args, **kwargs
    ):
        """Run the task on the agent with the given id

        Args:
            agent_id (_type_): _description_
            task (str, optional): _description_. Defaults to None.

        Raises:
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        self.logger.info(f"Running task {task} on agent {agent_id}")
        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    out = agent.run(task, *args, **kwargs)
                    return out
        except Exception as error:
            self.logger.error(f"Error running task on agent: {error}")
            raise error

    def run_many_agents(
        self, task: Optional[str] = None, *args, **kwargs
    ) -> List:
        """Run the task on all agents

        Args:
            task (str, optional): _description_. Defaults to None.

        Returns:
            List: _description_
        """
        self.logger.info(f"Running task {task} on all agents")
        try:
            return [
                agent.run(task, *args, **kwargs)
                for agent in self.agents
            ]
        except Exception as error:
            logger.error(f"Error running task on agents: {error}")
            raise error

    def list_agents(self):
        """List all agents."""
        self.logger.info("[Listing all active agents]")

        try:
            # Assuming self.agents is a list of agent objects
            for agent in self.agents:
                self.logger.info(
                    f"[Agent] [ID: {agent.id}] [Name:"
                    f" {agent.agent_name}] [Description:"
                    f" {agent.agent_description}] [Status: Running]"
                )
        except Exception as error:
            self.logger.error(f"Error listing agents: {error}")
            raise

    def get_agent(self, agent_id):
        """Get agent by id

        Args:
            agent_id (_type_): _description_

        Returns:
            _type_: _description_
        """
        self.logger.info(f"Getting agent {agent_id}")

        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    return agent
            raise ValueError(f"No agent found with ID {agent_id}")
        except Exception as error:
            self.logger.error(f"Error getting agent: {error}")
            raise error

    def add_agent(self, agent: Agent):
        """Add agent to the agent pool

        Args:
            agent (_type_): _description_
        """
        self.logger.info(f"Adding agent {agent} to pool")
        try:
            self.agents.append(agent)
        except Exception as error:
            print(f"Error adding agent to pool: {error}")
            raise error

    def remove_agent(self, agent_id):
        """Remove agent from the agent pool

        Args:
            agent_id (_type_): _description_
        """
        self.logger.info(f"Removing agent {agent_id} from pool")
        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    self.agents.remove(agent)
                    return
            raise ValueError(f"No agent found with ID {agent_id}")
        except Exception as error:
            print(f"Error removing agent from pool: {error}")
            raise error

    async def async_remove_agent(self, agent_id):
        """Remove agent from the agent pool

        Args:
            agent_id (_type_): _description_
        """
        self.logger.info(f"Removing agent {agent_id} from pool")
        try:
            # Remove agent from pool asynchronously with asyncio
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, self.remove_agent, agent_id
            )
        except Exception as error:
            print(f"Error removing agent from pool: {error}")
            raise error

    def scale_up(self, num_agents: int = 1):
        """Scale up the agent pool

        Args:
            num_agents (int, optional): _description_. Defaults to 1.
        """
        self.logger.info(f"Scaling up agent pool by {num_agents}")
        try:
            for _ in range(num_agents):
                self.agents.append(Agent())
        except Exception as error:
            print(f"Error scaling up agent pool: {error}")
            raise error

    def scale_down(self, num_agents: int = 1):
        """Scale down the agent pool

        Args:
            num_agents (int, optional): _description_. Defaults to 1.
        """
        for _ in range(num_agents):
            self.agents.pop()

    @tenacity.retry(
        wait=tenacity.wait_fixed(1),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(Exception),
    )
    def run(self, *args, **kwargs):
        """Run the swarm network"""
        app = self.get_app()

        try:
            import uvicorn

            logger.info(
                f"Running the swarm network with {len(self.agents)} agents on {self.host}:{self.port}"
            )
            uvicorn.run(
                app,
                host=self.host,
                port=self.port,
                # workers=get_number_of_workers(),
                *args,
                **kwargs,
            )

            return app
        except Exception as error:
            logger.error(f"Error running the swarm network: {error}")
            raise error


# # # Example usage
# if __name__ == "__main__":

#     agent1 = Agent(
#         agent_name="Covid-19-Chat",
#         agent_description="This agent provides information about COVID-19 symptoms.",
#         llm=OpenAIChat(),
#         max_loops="auto",
#         autosave=True,
#         verbose=True,
#         stopping_condition="finish",
#     )

#     agents = [agent1]  # Add more agents as needed
#     swarm_name = "HealthSwarm"
#     swarm_description = (
#         "A swarm of agents providing health-related information."
#     )

#     agent_api = SwarmNetwork(swarm_name, swarm_description, agents)
#     agent_api.run()
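The docstring describes a scaling policy (grow when pending tasks exceed the pool, shrink otherwise) that the class does not implement itself. A minimal sketch of what such a monitor could look like; rebalance is an assumption, not part of the class:

def rebalance(network: SwarmNetwork) -> None:
    # Compare queued work against pool size and adjust, as the docstring describes
    pending = network.task_queue.qsize()
    if pending > len(network.agents):
        network.scale_up(pending - len(network.agents))
    elif pending < len(network.agents) and len(network.agents) > 1:
        network.scale_down(1)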
@ -0,0 +1,146 @@
"""
Package installation utility that checks for package existence and installs if needed.
Supports both pip and conda package managers.
"""

import importlib.util
import subprocess
import sys
from typing import Literal, Optional, Union

from swarms.utils.loguru_logger import initialize_logger

import pkg_resources


logger = initialize_logger("autocheckpackages")


def check_and_install_package(
    package_name: str,
    package_manager: Literal["pip", "conda"] = "pip",
    version: Optional[str] = None,
    upgrade: bool = False,
) -> bool:
    """
    Check if a package is installed and install it if not found.

    Args:
        package_name: Name of the package to check/install
        package_manager: Package manager to use ('pip' or 'conda')
        version: Specific version to install (optional)
        upgrade: Whether to upgrade the package if it exists

    Returns:
        bool: True if package is available after check/install, False if installation failed

    Raises:
        ValueError: If invalid package manager is specified
    """
    try:
        # Check if package exists
        if package_manager == "pip":
            try:
                pkg_resources.get_distribution(package_name)
                if not upgrade:
                    logger.info(
                        f"Package {package_name} is already installed"
                    )
                    return True
            except pkg_resources.DistributionNotFound:
                pass

            # Construct installation command
            cmd = [sys.executable, "-m", "pip", "install"]
            if upgrade:
                cmd.append("--upgrade")

            if version:
                cmd.append(f"{package_name}=={version}")
            else:
                cmd.append(package_name)

        elif package_manager == "conda":
            # Check if conda is available
            try:
                subprocess.run(
                    ["conda", "--version"],
                    check=True,
                    capture_output=True,
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                logger.error(
                    "Conda is not available. Please install conda first."
                )
                return False

            # Construct conda command
            cmd = ["conda", "install", "-y"]
            if version:
                cmd.append(f"{package_name}={version}")
            else:
                cmd.append(package_name)
        else:
            raise ValueError(
                f"Invalid package manager: {package_manager}"
            )

        # Run installation
        logger.info(f"Installing {package_name}...")
        subprocess.run(
            cmd, check=True, capture_output=True, text=True
        )

        # Verify installation
        try:
            importlib.import_module(package_name)
            logger.info(f"Successfully installed {package_name}")
            return True
        except ImportError:
            logger.error(
                f"Package {package_name} was installed but cannot be imported"
            )
            return False

    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install {package_name}: {e.stderr}")
        return False
    except Exception as e:
        logger.error(
            f"Unexpected error while installing {package_name}: {str(e)}"
        )
        return False


def auto_check_and_download_package(
    packages: Union[str, list[str]],
    package_manager: Literal["pip", "conda"] = "pip",
    upgrade: bool = False,
) -> bool:
    """
    Ensure multiple packages are installed.

    Args:
        packages: Single package name or list of package names
        package_manager: Package manager to use ('pip' or 'conda')
        upgrade: Whether to upgrade existing packages

    Returns:
        bool: True if all packages are available, False if any installation failed
    """
    if isinstance(packages, str):
        packages = [packages]

    success = True
    for package in packages:
        if ":" in package:
            name, version = package.split(":")
            if not check_and_install_package(
                name, package_manager, version, upgrade
            ):
                success = False
        else:
            if not check_and_install_package(
                package, package_manager, upgrade=upgrade
            ):
                success = False

    return success
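A short usage sketch for the utility above; the package names and versions are illustrative:

if __name__ == "__main__":
    # Single package, pinned with the "name:version" convention parsed above
    auto_check_and_download_package("requests:2.31.0")
    # Several packages at once, upgrading any that already exist
    auto_check_and_download_package(["loguru", "tqdm"], upgrade=True)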
@ -0,0 +1,263 @@
"""
Lazy Package Loader

This module provides utilities for lazy loading Python packages to improve startup time
and reduce memory usage by only importing packages when they are actually used.

Features:
- Type-safe lazy loading of packages
- Support for nested module imports
- Auto-completion support in IDEs
- Thread-safe implementation
- Comprehensive test coverage
"""

from types import ModuleType
from typing import (
    Optional,
    Dict,
    Any,
    Callable,
    Type,
    TypeVar,
    Union,
    cast,
)
import importlib
import functools
import threading
from importlib.util import find_spec
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)


T = TypeVar("T")
C = TypeVar("C")


class ImportError(Exception):
    """Raised when a lazy import fails. Note: this shadows the builtin ImportError."""

    pass


class LazyLoader:
    """
    A thread-safe lazy loader for Python packages that only imports them when accessed.

    Attributes:
        _module_name (str): The name of the module to be lazily loaded
        _module (Optional[ModuleType]): The cached module instance once loaded
        _lock (threading.Lock): Thread lock for safe concurrent access

    Examples:
        >>> np = LazyLoader('numpy')
        >>> # numpy is not imported yet
        >>> result = np.array([1, 2, 3])
        >>> # numpy is imported only when first used
    """

    def __init__(self, module_name: str) -> None:
        """
        Initialize the lazy loader with a module name.

        Args:
            module_name: The fully qualified name of the module to lazily load

        Raises:
            ImportError: If the module cannot be found in sys.path
        """
        self._module_name = module_name
        self._module: Optional[ModuleType] = None
        self._lock = threading.Lock()

        auto_check_and_download_package(
            module_name, package_manager="pip"
        )

        # Verify module exists without importing it
        if find_spec(module_name) is None:
            raise ImportError(
                f"Module '{module_name}' not found in sys.path"
            )

    def _load_module(self) -> ModuleType:
        """
        Thread-safe module loading.

        Returns:
            ModuleType: The loaded module

        Raises:
            ImportError: If module import fails
        """
        if self._module is None:
            with self._lock:
                # Double-check pattern
                if self._module is None:
                    try:
                        self._module = importlib.import_module(
                            self._module_name
                        )
                    except Exception as e:
                        raise ImportError(
                            f"Failed to import '{self._module_name}': {str(e)}"
                        )
        return cast(ModuleType, self._module)

    def __getattr__(self, name: str) -> Any:
        """
        Intercepts attribute access to load the module if needed.

        Args:
            name: The attribute name being accessed

        Returns:
            Any: The requested attribute from the loaded module

        Raises:
            AttributeError: If the attribute doesn't exist in the module
        """
        module = self._load_module()
        try:
            return getattr(module, name)
        except AttributeError:
            raise AttributeError(
                f"Module '{self._module_name}' has no attribute '{name}'"
            )

    def __dir__(self) -> list[str]:
        """
        Returns list of attributes for autocomplete support.

        Returns:
            List[str]: Available attributes in the module
        """
        return dir(self._load_module())

    def is_loaded(self) -> bool:
        """
        Check if the module has been loaded.

        Returns:
            bool: True if module is loaded, False otherwise
        """
        return self._module is not None


class LazyLoaderMetaclass(type):
    """Metaclass to handle lazy loading behavior"""

    def __call__(cls, *args, **kwargs):
        if hasattr(cls, "_lazy_loader"):
            return super().__call__(*args, **kwargs)
        return super().__call__(*args, **kwargs)


class LazyClassLoader:
    """
    A descriptor that creates the actual class only when accessed,
    with proper inheritance support.
    """

    def __init__(
        self, class_name: str, bases: tuple, namespace: Dict[str, Any]
    ):
        self.class_name = class_name
        self.bases = bases
        self.namespace = namespace
        self._real_class: Optional[Type] = None
        self._lock = threading.Lock()

    def _create_class(self) -> Type:
        """Creates the actual class if it hasn't been created yet."""
        if self._real_class is None:
            with self._lock:
                if self._real_class is None:
                    # Update namespace to include metaclass
                    namespace = dict(self.namespace)
                    namespace["__metaclass__"] = LazyLoaderMetaclass

                    # Create the class with metaclass
                    new_class = LazyLoaderMetaclass(
                        self.class_name, self.bases, namespace
                    )

                    # Store reference to this loader
                    new_class._lazy_loader = self
                    self._real_class = new_class

        return cast(Type, self._real_class)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Creates an instance of the lazy loaded class."""
        real_class = self._create_class()
        # Use the metaclass __call__ method
        return real_class(*args, **kwargs)

    def __instancecheck__(self, instance: Any) -> bool:
        """Support for isinstance() checks"""
        real_class = self._create_class()
        return isinstance(instance, real_class)

    def __subclasscheck__(self, subclass: Type) -> bool:
        """Support for issubclass() checks"""
        real_class = self._create_class()
        return issubclass(subclass, real_class)


def lazy_import(*names: str) -> Dict[str, LazyLoader]:
    """
    Create multiple lazy loaders at once.

    Args:
        *names: Module names to create lazy loaders for

    Returns:
        Dict[str, LazyLoader]: Dictionary mapping module names to their lazy loaders

    Examples:
        >>> modules = lazy_import('numpy', 'pandas', 'matplotlib.pyplot')
        >>> np = modules['numpy']
        >>> pd = modules['pandas']
        >>> plt = modules['matplotlib.pyplot']
    """
    return {name.split(".")[-1]: LazyLoader(name) for name in names}


def lazy_import_decorator(
    target: Union[Callable[..., T], Type[C]]
) -> Union[Callable[..., T], Type[C], LazyClassLoader]:
    """
    Enhanced decorator that supports both lazy imports and lazy class loading.
    """
    if isinstance(target, type):
        # Store the original class details
        namespace = {
            name: value
            for name, value in target.__dict__.items()
            if not name.startswith("__")
            or name in ("__init__", "__new__")
        }

        # Create lazy loader
        loader = LazyClassLoader(
            target.__name__, target.__bases__, namespace
        )

        # Preserve class metadata
        loader.__module__ = target.__module__
        loader.__doc__ = target.__doc__

        # Add reference to original class
        loader._original_class = target

        return loader
    else:
        # Handle function decoration
        @functools.wraps(target)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            return target(*args, **kwargs)

        return wrapper
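A short usage sketch of LazyLoader, mirroring the docstring example; it assumes numpy is installable in the environment:

np = LazyLoader("numpy")   # resolved via find_spec, but not imported yet
print(np.is_loaded())      # False
arr = np.array([1, 2, 3])  # first attribute access triggers the real import
print(np.is_loaded())      # True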
@ -1,73 +0,0 @@
|
|||||||
import os
import tempfile

import pygame
from loguru import logger
from openai import OpenAI


class OpenAITTS:
    """
    A class to interact with the OpenAI TTS API and play the generated audio.
    """

    def __init__(self, *args, **kwargs):
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs
        )
        pygame.init()

    def run(
        self, task: str, play_sound: bool = True, *args, **kwargs
    ):
        """
        Run a task with the OpenAI API and optionally play the generated audio.

        Args:
            task (str): The text to synthesize.
            play_sound (bool): If True, play the generated audio.

        Returns:
            None
        """
        try:
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="nova",
                input=task,
                *args,
                **kwargs,
            )
            logger.info("Task completed successfully.")

            if play_sound:
                # The speech endpoint returns the audio bytes directly
                # (there is no "url" field in the response), so write the
                # payload to a temporary file before handing it to pygame.
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=".mp3"
                ) as tmp_file:
                    tmp_file.write(response.content)
                pygame.mixer.music.load(tmp_file.name)
                pygame.mixer.music.play()
                while pygame.mixer.music.get_busy():
                    pygame.time.Clock().tick(10)
        except Exception as e:
            logger.error(f"Error during task execution: {str(e)}")


# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY"))
# client.run("Hello world! This is a streaming test.", play_sound=True)


def text_to_speech(
    task: str, play_sound: bool = True, *args, **kwargs
):
    return OpenAITTS().run(task, play_sound=play_sound, *args, **kwargs)


# print(text_to_speech(task="hello"))
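
The class above still buffers the full audio before playback despite the "streaming" framing. For reference, a minimal sketch of genuinely chunked retrieval, assuming the with_streaming_response interface available in recent openai-python 1.x releases; this is illustrative usage, not part of this commit:

# Sketch only: assumes openai-python exposes
# client.audio.speech.with_streaming_response (recent 1.x releases).
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
with client.audio.speech.with_streaming_response.create(
    model="tts-1", voice="nova", input="Hello world!"
) as response:
    response.stream_to_file("speech.mp3")  # written chunk by chunk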
@ -0,0 +1,42 @@
from swarms.structs.tree_swarm import ForestSwarm, Tree, TreeAgent


agents_tree1 = [
    TreeAgent(
        system_prompt="Stock Analysis Agent",
        agent_name="Stock Analysis Agent",
    ),
    TreeAgent(
        system_prompt="Financial Planning Agent",
        agent_name="Financial Planning Agent",
    ),
    TreeAgent(
        system_prompt="Retirement Strategy Agent",
        agent_name="Retirement Strategy Agent",
    ),
]

agents_tree2 = [
    TreeAgent(
        system_prompt="Tax Filing Agent",
        agent_name="Tax Filing Agent",
    ),
    TreeAgent(
        system_prompt="Investment Strategy Agent",
        agent_name="Investment Strategy Agent",
    ),
    TreeAgent(
        system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
    ),
]

# Create trees
tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)

# Create the ForestSwarm
multi_agent_structure = ForestSwarm(trees=[tree1, tree2])

# Run a task
task = "Our company is incorporated in Delaware. How do we do our taxes for free?"
multi_agent_structure.run(task)
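
A minimal single-tree variant of the example above, under the same assumptions about the ForestSwarm API; it reuses only the constructors shown in this file, and it assumes run returns the final answer:

# Single-tree variant; constructor arguments mirror the example above.
tax_tree = Tree(
    tree_name="Tax Tree",
    agents=[
        TreeAgent(
            system_prompt="Tax Filing Agent",
            agent_name="Tax Filing Agent",
        )
    ],
)
forest = ForestSwarm(trees=[tax_tree])
answer = forest.run("Which federal forms does a Delaware C corp file?")
print(answer)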
@ -0,0 +1,206 @@
import random
import re

from loguru import logger

from swarms import Agent

# Configure loguru
logger.add("zkp_log.log", rotation="500 KB", retention="10 days", level="INFO")


class ProverAgent:
    """
    Prover Agent for Zero Knowledge Proof.

    Responsibilities:
    - Generate commitments based on a secret.
    - Respond to challenges from the Verifier.

    Attributes:
        agent (Agent): Swarms agent instance.
        p (int): The prime modulus.
        g (int): The generator.
        x (int): The Prover's secret.
    """

    def __init__(self, p: int, g: int, secret: int):
        self.p = p
        self.g = g
        self.x = secret  # Prover's secret
        self.agent = Agent(
            agent_name="ProverAgent",
            model_name="gpt-4o-mini",
            max_loops=1,
            interactive=False,
            streaming_on=True,
            system_prompt=(
                "You are the Prover in a Zero Knowledge Proof (ZKP) system. "
                "Your responsibilities are to generate commitments based on a secret value and "
                "respond to challenges from the Verifier without revealing the secret. "
                "Follow the mathematical rules of modular arithmetic when performing computations."
            ),
        )
        logger.info("Initialized ProverAgent with p={}, g={}, secret={}", p, g, secret)

    def generate_commitment(self) -> tuple[int, int]:
        """
        Generates a random commitment for the proof.

        Returns:
            tuple[int, int]: The random value (r) and the commitment (t).
        """
        r = random.randint(1, self.p - 2)
        task = (
            f"Compute the commitment t = g^r % p for g={self.g}, r={r}, p={self.p}. "
            "Return only the numerical value of t as an integer."
        )
        t = self.agent.run(task=task)
        t_value = self._extract_integer(t, "commitment")
        logger.info("Prover generated commitment: r={}, t={}", r, t_value)
        return r, t_value

    def _extract_integer(self, response: str, label: str) -> int:
        """
        Extracts an integer from the LLM response.

        Args:
            response (str): The response from the agent.
            label (str): A label for logging purposes.

        Returns:
            int: The extracted integer value.
        """
        try:
            # Use regex to find the first integer in the response
            match = re.search(r"\b\d+\b", response)
            if match:
                return int(match.group(0))
            raise ValueError(f"No integer found in {label} response: {response}")
        except Exception as e:
            logger.error("Failed to extract integer from {} response: {}", label, response)
            raise ValueError(f"Invalid {label} response: {response}") from e

    def respond_to_challenge(self, r: int, c: int) -> int:
        """
        Computes the response to a challenge.

        Args:
            r (int): The random value used in the commitment.
            c (int): The challenge issued by the Verifier.

        Returns:
            int: The response (z).
        """
        task = (
            f"Compute the response z = (r + c * x) % (p - 1) for r={r}, c={c}, "
            f"x={self.x}, p={self.p}. Return only the numerical value of z as an integer."
        )
        z = self.agent.run(task=task)
        # Parse the LLM output robustly rather than calling int() on raw text.
        z_value = self._extract_integer(z, "challenge response")
        logger.info("Prover responded to challenge: z={}", z_value)
        return z_value


class VerifierAgent:
    """
    Verifier Agent for Zero Knowledge Proof.

    Responsibilities:
    - Issue challenges to the Prover.
    - Verify the Prover's response.

    Attributes:
        agent (Agent): Swarms agent instance.
        p (int): The prime modulus.
        g (int): The generator.
        y (int): The public value from the Prover.
    """

    def __init__(self, p: int, g: int, y: int):
        self.p = p
        self.g = g
        self.y = y  # Public value
        self.agent = Agent(
            agent_name="VerifierAgent",
            model_name="gpt-4o-mini",
            max_loops=1,
            interactive=False,
            streaming_on=True,
            system_prompt=(
                "You are the Verifier in a Zero Knowledge Proof (ZKP) system. "
                "Your responsibilities are to issue random challenges and verify the Prover's response. "
                "Use modular arithmetic to check if the proof satisfies g^z % p == (t * y^c) % p. "
                "Answer with exactly 'true' or 'false'."
            ),
        )
        logger.info("Initialized VerifierAgent with p={}, g={}, y={}", p, g, y)

    def issue_challenge(self) -> int:
        """
        Issues a random challenge to the Prover.

        Returns:
            int: The challenge value (c).
        """
        c = random.randint(1, 10)
        logger.info("Verifier issued challenge: c={}", c)
        return c

    def verify_proof(self, t: int, z: int, c: int) -> bool:
        """
        Verifies the Prover's response.

        Args:
            t (int): The commitment from the Prover.
            z (int): The response from the Prover.
            c (int): The challenge issued to the Prover.

        Returns:
            bool: True if the proof is valid, False otherwise.
        """
        task = (
            f"Verify if g^z % p == (t * y^c) % p for g={self.g}, z={z}, p={self.p}, "
            f"t={t}, y={self.y}, c={c}. Answer with exactly 'true' or 'false'."
        )
        verification_result = self.agent.run(task=task)
        # Tolerate minor formatting noise (e.g. "True.") around the verdict.
        is_valid = "true" in verification_result.strip().lower()
        logger.info("Verifier checked proof: t={}, z={}, c={}, valid={}", t, z, c, is_valid)
        return is_valid


class CoordinatorAgent:
    """
    Coordinator for orchestrating the Zero Knowledge Proof protocol.

    Responsibilities:
    - Initialize parameters.
    - Facilitate interaction between Prover and Verifier agents.
    """

    def __init__(self, p: int, g: int, secret: int):
        self.p = p
        self.g = g
        self.prover = ProverAgent(p, g, secret)
        y = pow(g, secret, p)  # Public value
        self.verifier = VerifierAgent(p, g, y)
        logger.info("Coordinator initialized with p={}, g={}, secret={}", p, g, secret)

    def orchestrate(self) -> bool:
        """
        Orchestrates the Zero Knowledge Proof protocol.

        Returns:
            bool: True if the proof is valid, False otherwise.
        """
        logger.info("Starting ZKP protocol orchestration.")
        r, t = self.prover.generate_commitment()
        c = self.verifier.issue_challenge()
        z = self.prover.respond_to_challenge(r, c)
        is_valid = self.verifier.verify_proof(t, z, c)
        logger.info("ZKP protocol completed. Valid proof: {}", is_valid)
        return is_valid


if __name__ == "__main__":
    # Example parameters
    p = 23  # Prime modulus
    g = 5  # Generator (primitive root mod 23)
    secret = 7  # Prover's secret

    # Initialize the Coordinator and run the protocol
    coordinator = CoordinatorAgent(p, g, secret)
    result = coordinator.orchestrate()
    print(f"Zero Knowledge Proof Verification Result: {'Valid' if result else 'Invalid'}")