parent 40a78c4d06
commit 3b99f3db17
@@ -0,0 +1,112 @@
import requests
import json
from time import sleep

BASE_URL = "http://api.swarms.ai:8000"


def make_request(method, endpoint, data=None):
    """Helper function to make requests with error handling."""
    url = f"{BASE_URL}{endpoint}"
    try:
        if method == "GET":
            response = requests.get(url)
        elif method == "POST":
            response = requests.post(url, json=data)
        elif method == "DELETE":
            response = requests.delete(url)
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")

        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error making {method} request to {endpoint}: {str(e)}")
        # e.response is None for connection-level failures
        if e.response is not None:
            print(f"Response text: {e.response.text}")
        return None


def create_agent():
    """Create a test agent."""
    data = {
        "agent_name": "test_agent",
        "model_name": "gpt-4",
        "system_prompt": "You are a helpful assistant",
        "description": "Test agent",
        "temperature": 0.7,
        "max_loops": 1,
        "tags": ["test"],
    }
    return make_request("POST", "/v1/agent", data)


def list_agents():
    """List all agents."""
    return make_request("GET", "/v1/agents")


def test_completion(agent_id):
    """Test a completion with the agent."""
    data = {
        "prompt": "Say hello!",
        "agent_id": agent_id,
        "max_tokens": 100,
    }
    return make_request("POST", "/v1/agent/completions", data)


def get_agent_metrics(agent_id):
    """Get metrics for an agent."""
    return make_request("GET", f"/v1/agent/{agent_id}/metrics")


def delete_agent(agent_id):
    """Delete an agent."""
    return make_request("DELETE", f"/v1/agent/{agent_id}")


def run_tests():
    print("Starting API tests...")

    # Create an agent
    print("\n1. Creating agent...")
    agent_response = create_agent()
    if not agent_response:
        print("Failed to create agent")
        return

    agent_id = agent_response.get("agent_id")
    print(f"Created agent with ID: {agent_id}")

    # Give the server a moment to process
    sleep(2)

    # List agents
    print("\n2. Listing agents...")
    agents = list_agents()
    if agents is not None:
        print(f"Found {len(agents)} agents")

    # Test completion
    if agent_id:
        print("\n3. Testing completion...")
        completion = test_completion(agent_id)
        if completion:
            print(f"Completion response: {completion.get('response')}")

        print("\n4. Getting agent metrics...")
        metrics = get_agent_metrics(agent_id)
        if metrics:
            print(f"Agent metrics: {json.dumps(metrics, indent=2)}")

        # Clean up
        # print("\n5. Cleaning up - deleting agent...")
        # delete_result = delete_agent(agent_id)
        # if delete_result:
        #     print("Successfully deleted agent")


if __name__ == "__main__":
    run_tests()
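# A small sketch (assumption: deployments may expose the API on a different
# host); reading the base URL from the environment would keep the smoke test
# portable without touching the code above:
#
#   import os
#   BASE_URL = os.getenv("SWARMS_API_URL", "http://api.swarms.ai:8000")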
@@ -1,898 +0,0 @@
from enum import Enum, auto
from typing import Union, Optional, List, Dict, Tuple, Any
from dataclasses import dataclass
import io
import struct
import wave

import magic  # python-magic, used by ModalityDetector below (was missing from the original imports)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image
from loguru import logger
from einops import rearrange


@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False


class MultiQueryAttention(nn.Module):
    """Fixed Multi-Query Attention implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_query_heads * self.head_dim
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim, config.hidden_size
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [seq_len, batch, heads, head_dim]
        q = q.view(
            batch_size, seq_length, self.num_query_heads, self.head_dim
        ).permute(1, 0, 2, 3)
        k = k.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).permute(1, 0, 2, 3)
        v = v.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).permute(1, 0, 2, 3)

        # Rotary embeddings were planned here but are not implemented
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Expand kv heads so each query head has a matching kv head (MQA/GQA)
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads, dim=2
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads, dim=2
            )

        # Compute attention
        # Reshape for matmul: [batch, heads, seq_length, head_dim]
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)

        return output
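# A minimal shape sanity check for the attention block (a sketch; the tiny
# config values are illustrative, not tuned). Kept commented out so the module
# stays side-effect free on import:
#
#   cfg = ModelConfig(hidden_size=64, num_query_heads=8, num_key_value_heads=2)
#   attn = MultiQueryAttention(cfg)
#   out = attn(torch.randn(2, 16, 64))
#   assert out.shape == (2, 16, 64)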
class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size, eps=config.layer_norm_eps
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size, eps=config.layer_norm_eps
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Create causal mask if needed
        if attention_mask is None:
            causal = torch.triu(
                torch.ones(
                    (input_ids.size(1), input_ids.size(1)),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            # Build an additive float mask; masked_fill on a bool tensor with
            # -inf would raise, so allocate a float tensor and fill that
            attention_mask = torch.zeros(
                causal.shape, device=input_ids.device, dtype=hidden_states.dtype
            ).masked_fill(causal, float("-inf"))

        # Apply transformer layers (pre-norm residual blocks)
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss
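    # A typical byte-LM training pair (a sketch): targets are the inputs
    # shifted left by one position, so the model learns next-byte prediction.
    #
    #   ids = torch.randint(0, 256, (2, 17))
    #   loss = model.compute_loss(ids[:, :-1], ids[:, 1:])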
    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass; keep only the logits for the last position
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        logits[i, token_id] /= repetition_penalty

            # Apply top-k sampling
            if top_k is not None:
                indices_to_remove = (
                    logits < torch.topk(logits, top_k)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the
                # threshold, always keeping the most likely token
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        return tensor_to_data(tensor_data)


class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"


class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(byte_indices.cpu().numpy().astype(np.uint8).tolist())

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Fall back to replacement characters for invalid sequences
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If that failed, assume raw pixel data
            if not size:
                # Try to determine a square size from the byte count
                pixel_count = len(byte_sequence) // len(mode)
                side = int(np.sqrt(pixel_count))
                size = (side, side)

            # Convert raw bytes to a pixel array
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((*size, len(mode)))

            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine format string based on sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack bytes to samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to the desired format.

        Args:
            model_output: Either a tensor from the model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in the specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError("Video decoding not yet implemented")
        else:  # BINARY
            return byte_sequence
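# Quick round-trip sketch for the detokenizer (plain UTF-8 bytes in, text out):
#
#   detok = ByteDetokenizer()
#   assert detok.decode_data(b"hello", DataType.TEXT) == "hello"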
# Usage example


class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about a detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None


class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers)
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"RIFF": "WEBP",
        # Audio
        b"RIFF....WAVE": "WAV",
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }

    def __init__(self):
        self.magic = magic.Magic(mime=True)

    def _check_text_probability(self, data: bytes) -> float:
        """Estimate the probability that data is text."""
        # Check if data is valid UTF-8
        try:
            data.decode("utf-8")
            # Count printable ASCII characters
            printable = sum(1 for b in data if 32 <= b <= 126)
            return printable / len(data)
        except UnicodeDecodeError:
            return 0.0

    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is a valid image and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                img = Image.open(bio)
                return True, {
                    "format": img.format,
                    "size": img.size,
                    "mode": img.mode,
                }
        except Exception:
            return False, {}

    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is valid audio and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                # Try to parse as WAV
                with wave.open(bio) as wav:
                    return True, {
                        "channels": wav.getnchannels(),
                        "sample_width": wav.getsampwidth(),
                        "framerate": wav.getframerate(),
                        "frames": wav.getnframes(),
                    }
        except Exception:
            # Check for other audio signatures
            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
                if data.startswith(sig):
                    return True, {"format": "compressed_audio"}
            return False, {}

    def _detect_boundaries(
        self, data: bytes
    ) -> List[Tuple[int, int, Modality]]:
        """Detect boundaries between different modalities in the data."""
        boundaries = []
        current_pos = 0

        while current_pos < len(data):
            # Look for known signatures
            for sig, format_type in self.SIGNATURES.items():
                if data[current_pos:].startswith(sig):
                    # Found a signature; determine its length
                    if format_type in ["JPEG", "PNG", "GIF"]:
                        # Find the image end
                        try:
                            with io.BytesIO(data[current_pos:]) as bio:
                                img = Image.open(bio)
                                img.verify()
                                size = bio.tell()
                                boundaries.append(
                                    (
                                        current_pos,
                                        current_pos + size,
                                        Modality.IMAGE,
                                    )
                                )
                                current_pos += size
                                continue
                        except Exception:
                            pass

            # Check for text sections
            text_prob = self._check_text_probability(
                data[current_pos : current_pos + 1024]
            )
            if text_prob > 0.8:
                # Look for the end of the text section
                end_pos = current_pos + 1
                while end_pos < len(data):
                    if (
                        self._check_text_probability(
                            data[end_pos : end_pos + 32]
                        )
                        < 0.5
                    ):
                        break
                    end_pos += 1
                boundaries.append((current_pos, end_pos, Modality.TEXT))
                current_pos = end_pos
                continue

            current_pos += 1

        return boundaries

    def detect_modality(self, data: bytes) -> ModalityInfo:
        """Detect the modality of a byte sequence."""
        # First check for a single modality
        mime_type = self.magic.from_buffer(data)

        # Check text
        text_prob = self._check_text_probability(data)
        if text_prob > 0.9:
            return ModalityInfo(
                modality=Modality.TEXT,
                confidence=text_prob,
                metadata={"mime_type": mime_type},
            )

        # Check image
        is_image, image_meta = self._check_image_validity(data)
        if is_image:
            return ModalityInfo(
                modality=Modality.IMAGE,
                confidence=1.0,
                metadata={**image_meta, "mime_type": mime_type},
            )

        # Check audio
        is_audio, audio_meta = self._check_audio_validity(data)
        if is_audio:
            return ModalityInfo(
                modality=Modality.AUDIO,
                confidence=1.0,
                metadata={**audio_meta, "mime_type": mime_type},
            )

        # Check for multimodal content
        boundaries = self._detect_boundaries(data)
        if len(boundaries) > 1:
            sub_modalities = []
            for start, end, modality in boundaries:
                chunk_data = data[start:end]
                sub_info = self.detect_modality(chunk_data)
                if sub_info.modality != Modality.BINARY:
                    sub_modalities.append(sub_info)

            if sub_modalities:
                return ModalityInfo(
                    modality=Modality.MULTIMODAL,
                    confidence=0.8,
                    metadata={"mime_type": "multipart/mixed"},
                    sub_modalities=sub_modalities,
                )

        # Default to binary
        return ModalityInfo(
            modality=Modality.BINARY,
            confidence=0.5,
            metadata={"mime_type": mime_type},
        )

    def split_modalities(
        self, data: bytes
    ) -> List[Tuple[Modality, bytes, Dict]]:
        """Split multimodal data into separate modalities."""
        boundaries = self._detect_boundaries(data)
        result = []

        for start, end, modality in boundaries:
            chunk = data[start:end]
            info = self.detect_modality(chunk)
            result.append((modality, chunk, info.metadata))

        return result


class AutoDetectBytesDecoder:
    """Decoder that automatically detects and decodes different modalities."""

    def __init__(self):
        self.detector = ModalityDetector()
        self.text_decoder = ByteDetokenizer()  # From the previous example

    def decode(
        self, data: bytes
    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
        """Automatically detect and decode a byte sequence."""
        info = self.detector.detect_modality(data)

        if info.modality == Modality.MULTIMODAL:
            # Handle multimodal content
            parts = self.detector.split_modalities(data)
            return [self.decode(chunk) for modality, chunk, _ in parts]

        if info.modality == Modality.TEXT:
            return self.text_decoder.decode_text(data)
        elif info.modality == Modality.IMAGE:
            return self.text_decoder.decode_image(data)
        elif info.modality == Modality.AUDIO:
            return self.text_decoder.decode_audio(data)
        else:
            return data


# # Example usage
# def demo_auto_detection():
#     """Demonstrate auto modality detection."""
#     # Create mixed content
#     text = "Hello, World!".encode('utf-8')
#
#     # Create a small test image
#     img = Image.new('RGB', (100, 100), color='red')
#     img_bytes = io.BytesIO()
#     img.save(img_bytes, format='PNG')
#
#     # Combine into multimodal content
#     mixed_content = text + img_bytes.getvalue()
#
#     # Initialize decoder
#     decoder = AutoDetectBytesDecoder()
#
#     # Decode
#     result = decoder.decode(mixed_content)
#
#     if isinstance(result, list):
#         print("Detected multimodal content:")
#         for i, part in enumerate(result):
#             print(f"Part {i+1}: {type(part)}")
#
# if __name__ == "__main__":
#     demo_auto_detection()


def tensor_to_data(tensor: Tensor):
    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)

    # Initialize the auto-detector
    decoder = AutoDetectBytesDecoder()

    # Decode with automatic detection
    result = decoder.decode(byte_sequence)

    return result


def demo_byte_predictor():
    """Demo with smaller dimensions to test."""
    # Initialize model configuration with adjusted dimensions
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,  # Smaller for testing
        num_layers=2,  # Fewer layers for testing
        num_key_value_heads=2,
        num_query_heads=4,
        dropout=0.1,
        max_sequence_length=1024,
    )

    # Initialize model
    model = EnhancedBytePredictor(config)
    logger.info("Model initialized")

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    logger.info(f"Using device: {device}")

    # Create sample input data
    batch_size = 2
    seq_length = 16  # Shorter sequence for testing
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_length), device=device
    )
    logger.info(f"Created input tensor of shape: {input_ids.shape}")

    # Test forward pass
    try:
        logits = model(input_ids)
        logger.info(f"Forward pass successful! Output shape: {logits.shape}")

        # Test loss computation
        target_ids = torch.randint(
            0, config.vocab_size, (batch_size, seq_length), device=device
        )
        loss = model.compute_loss(input_ids, target_ids)
        logger.info(
            f"Loss computation successful! Loss value: {loss.item():.4f}"
        )

        # Test generation
        prompt = torch.randint(
            0,
            config.vocab_size,
            (1, 4),  # Very short prompt for testing
            device=device,
        )
        generated = model.generate(
            prompt, max_new_tokens=8, temperature=0.8, top_k=50
        )
        # generate() returns decoded data (str/bytes/array), not a tensor,
        # so log its type rather than a tensor shape
        logger.info(f"Generation successful! Output type: {type(generated)}")

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        raise


if __name__ == "__main__":
    # Optional logging setup:
    # logger.remove()  # Remove default handler
    # logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")

    demo_byte_predictor()
@@ -0,0 +1,8 @@
from swarms import Agent

Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    streaming_on=True,
).run("What are 5 hft algorithms")
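# A variant that captures the result (a sketch; that Agent.run returns the
# final response string is our assumption about the API, and `output` is a
# name we introduce for illustration):
#
#   output = Agent(
#       agent_name="Stock-Analysis-Agent",
#       model_name="gpt-4o-mini",
#       max_loops=1,
#   ).run("What are 5 hft algorithms")
#   print(output)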
@@ -0,0 +1,618 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from loguru import logger


@dataclass
class TransformerConfig:
    """Configuration class for MoE Transformer model parameters."""

    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_expert_layers: int = 4
    num_experts: int = 8
    expert_capacity: int = 32
    max_position_embeddings: int = 1024
    dropout_prob: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    num_query_groups: int = 4  # For multi-query attention


class ExpertLayer(nn.Module):
    """Individual expert feed-forward network."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
        self.fc2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MixtureOfExperts(nn.Module):
    """Mixture of Experts layer with dynamic routing."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity

        # Create expert networks
        self.experts = nn.ModuleList(
            [ExpertLayer(config) for _ in range(config.num_experts)]
        )

        # Router network
        self.router = nn.Linear(config.hidden_size, config.num_experts)

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Route inputs to the top-k experts and combine their outputs."""
        batch_size, seq_len, hidden_size = x.shape

        # Calculate routing probabilities
        router_logits = self.router(x)
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts and renormalize their gates
        top_k = 2
        gates, indices = torch.topk(routing_weights, top_k, dim=-1)
        gates = F.softmax(gates, dim=-1)

        # Process inputs through the selected experts
        final_output = torch.zeros_like(x)
        router_load = torch.zeros(self.num_experts, device=x.device)

        for i in range(top_k):
            expert_index = indices[..., i]
            gate = gates[..., i : i + 1]

            # Count expert assignments
            for j in range(self.num_experts):
                router_load[j] += (expert_index == j).float().sum()

            # Process through the selected experts
            for j in range(self.num_experts):
                mask = expert_index == j
                if not mask.any():
                    continue

                expert_input = x[mask]
                expert_output = self.experts[j](expert_input)
                final_output[mask] += gate[mask] * expert_output

        # Load-balancing auxiliary loss: the squared coefficient of variation
        # of per-expert assignment counts (0 when load is perfectly even)
        aux_loss = router_load.float().var() / (
            router_load.float().mean() ** 2
        )

        return final_output, {"load_balancing_loss": aux_loss}
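# A quick shape and routing sanity check (a sketch; the tiny sizes are
# illustrative only). Kept commented out so the module imports cleanly:
#
#   cfg = TransformerConfig(hidden_size=32, num_experts=4)
#   moe = MixtureOfExperts(cfg)
#   y, aux = moe(torch.randn(2, 8, 32))
#   assert y.shape == (2, 8, 32) and "load_balancing_loss" in aux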
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention mechanism with proper query-group handling."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.num_query_groups = config.num_query_groups
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Query projection maintains the full head dimension
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Key and value projections use a reduced number of heads (query groups)
        self.k_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )

        self.dropout = nn.Dropout(config.dropout_prob)

        # Heads per group, used when expanding keys/values to match queries
        self.heads_per_group = (
            self.num_attention_heads // self.num_query_groups
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Reshape queries to the full number of heads
        queries = queries.view(
            batch_size, seq_length, self.num_attention_heads, self.head_dim
        )

        # Reshape keys and values to the number of query groups
        keys = keys.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )
        values = values.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )

        # Transpose for batch matrix multiplication
        queries = queries.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)
        values = values.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)

        # Repeat keys and values for each head in the group
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)

        # Compute scaled dot-product attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # Expand the padding mask to the score dimensions and turn zeros
            # into large negative values before the softmax
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(
                batch_size, self.num_attention_heads, seq_length, seq_length
            )
            mask_value = torch.finfo(scores.dtype).min
            attention_mask = expanded_mask.eq(0).float() * mask_value
            scores = scores + attention_mask

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute the attention output
        attention_output = torch.matmul(attention_weights, values)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(
            batch_size, seq_length, -1
        )

        return attention_output, None
class MoETransformer(nn.Module):
    """
    Transformer model with Mixture of Experts and Multi-Query Attention.

    Features:
    - Multi-Query Attention mechanism for efficient inference
    - Mixture of Experts for dynamic routing and specialization
    - Load-balancing auxiliary loss for the expert router
    - Built-in logging and monitoring
    - Type annotations for better code maintainability
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Initialize components
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Multi-Query Attention layers
        self.attention_layers = nn.ModuleList(
            [
                MultiQueryAttention(config)
                for _ in range(config.num_expert_layers)
            ]
        )

        # Mixture of Experts layers
        self.moe_layers = nn.ModuleList(
            [
                MixtureOfExperts(config)
                for _ in range(config.num_expert_layers)
            ]
        )

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_prob)

        # Output projection
        self.output_projection = nn.Linear(
            config.hidden_size, config.vocab_size
        )

        # Initialize weights
        self.apply(self._init_weights)
        logger.info("Initialized MoETransformer model")

    def _init_weights(self, module: nn.Module):
        """Initialize model weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def get_position_embeddings(self, position_ids: Tensor) -> Tensor:
        """Generate position embeddings."""
        return self.position_embedding(position_ids)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Dict]:
        """
        Forward pass through the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask for padding
            position_ids: Position IDs for positional encoding
            cache: Cache for key/value states in generation

        Returns:
            tuple: (logits, auxiliary_outputs)
        """
        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Get embeddings
        inputs_embeds = self.embedding(input_ids)
        position_embeds = self.get_position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Initialize auxiliary outputs
        aux_outputs = {"moe_losses": []}

        # Process through transformer layers
        for attention_layer, moe_layer in zip(
            self.attention_layers, self.moe_layers
        ):
            # Multi-Query Attention
            attention_output, _ = attention_layer(
                hidden_states, attention_mask, cache
            )
            hidden_states = self.layer_norm(hidden_states + attention_output)

            # Mixture of Experts
            moe_output, moe_aux = moe_layer(hidden_states)
            hidden_states = self.layer_norm(hidden_states + moe_output)
            aux_outputs["moe_losses"].append(moe_aux["load_balancing_loss"])

        # Final output projection
        logits = self.output_projection(hidden_states)

        return logits, aux_outputs

    def fetch_loss(
        self,
        logits: Tensor,
        labels: Tensor,
        aux_outputs: Dict,
        reduction: str = "mean",
    ) -> Tensor:
        """
        Calculate the total loss including MoE balancing losses.

        Args:
            logits: Model output logits
            labels: Ground truth labels
            aux_outputs: Auxiliary outputs from the forward pass
            reduction: Loss reduction method

        Returns:
            Tensor: Total loss
        """
        # Calculate cross entropy loss
        ce_loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            labels.view(-1),
            reduction=reduction,
        )

        # Calculate MoE loss
        moe_loss = torch.stack(aux_outputs["moe_losses"]).mean()

        # Combine losses
        total_loss = ce_loss + 0.01 * moe_loss

        logger.debug(
            f"CE Loss: {ce_loss.item():.4f}, "
            f"MoE Loss: {moe_loss.item():.4f}"
        )

        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> Tensor:
        """
        Generate text using the model.

        Args:
            input_ids: Initial input tokens
            max_length: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_k: Number of highest probability tokens to keep
            top_p: Cumulative probability for nucleus sampling

        Returns:
            Tensor: Generated token IDs
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize sequence with input_ids
        generated = input_ids

        # Cache for key-value pairs (currently unused by the attention layers)
        cache = {}

        for _ in range(max_length):
            # Get position IDs for the current sequence
            position_ids = torch.arange(
                generated.shape[1], dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            # Forward pass
            logits, _ = self.forward(
                generated, position_ids=position_ids, cache=cache
            )

            # Get next token logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                indices_to_remove = (
                    next_token_logits
                    < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the
                # threshold, always keeping the most likely token
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                # Scatter the removal mask back to the original index order
                # (indexing with sorted indices directly would mix rows
                # across the batch)
                indices_to_remove = torch.zeros_like(
                    next_token_logits, dtype=torch.bool
                )
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to the sequence
            generated = torch.cat((generated, next_token), dim=1)

            # Check for end-of-sequence token
            if (next_token == self.config.vocab_size - 1).all():
                break

        return generated


# Initialize model configuration
config = TransformerConfig(
    vocab_size=50257,
    hidden_size=768,
    num_attention_heads=12,
    num_expert_layers=4,
    num_experts=8,
    expert_capacity=32,
    max_position_embeddings=1024,
    num_query_groups=4,
)


def prepare_sample_data(
    batch_size: int = 8,
    seq_length: int = 512,
    vocab_size: int = 50257,
) -> DataLoader:
    """Create sample data for demonstration."""
    # Create random input sequences
    input_ids = torch.randint(
        0, vocab_size, (100, seq_length)  # 100 samples
    )

    # Create random target sequences (real data would use inputs shifted by 1)
    labels = torch.randint(0, vocab_size, (100, seq_length))

    # Create attention masks (1 for real tokens, 0 for padding)
    attention_mask = torch.ones_like(input_ids)

    # Create dataset and dataloader
    dataset = TensorDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader


def train_step(
    model: MoETransformer,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
    """Execute a single training step."""
    model.train()
    optimizer.zero_grad()

    # Unpack batch
    input_ids, attention_mask, labels = [b.to(device) for b in batch]

    # Forward pass
    logits, aux_outputs = model(
        input_ids=input_ids, attention_mask=attention_mask
    )

    # Calculate loss
    loss = model.fetch_loss(logits, labels, aux_outputs)

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()


def main():
    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Initialize model
    model = MoETransformer(config).to(device)
    logger.info("Model initialized")

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=0.01
    )

    # Prepare data
    dataloader = prepare_sample_data()
    logger.info("Data prepared")

    # Training loop
    num_epochs = 3
    for epoch in range(num_epochs):
        epoch_losses = []

        for batch_idx, batch in enumerate(dataloader):
            loss = train_step(model, batch, optimizer, device)
            epoch_losses.append(loss)

            if batch_idx % 10 == 0:
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs} "
                    f"Batch {batch_idx}/{len(dataloader)} "
                    f"Loss: {loss:.4f}"
                )

        avg_loss = np.mean(epoch_losses)
        logger.info(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    # Generation example
    model.eval()
    with torch.no_grad():
        # Prepare input prompt
        prompt = torch.randint(0, config.vocab_size, (1, 10)).to(device)

        # Generate sequence
        generated = model.generate(
            input_ids=prompt,
            max_length=50,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

        logger.info(f"Generated sequence shape: {generated.shape}")


if __name__ == "__main__":
    main()
@@ -1,121 +0,0 @@
import json
import os
from difflib import SequenceMatcher

import requests
from dotenv import load_dotenv
from loguru import logger
from supabase import Client, create_client

load_dotenv()

# Initialize Supabase client
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Swarms API URL and headers
SWARMS_API_URL = "https://swarms.world/api/add-prompt"
SWARMS_API_KEY = os.getenv("SWARMS_API_KEY")
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {SWARMS_API_KEY}",
}

# Configure logger with a rotating log file
logger.add("fetch_and_publish_prompts.log", rotation="1 MB")


def fetch_and_publish_prompts():
    logger.info("Starting to fetch and publish prompts.")

    # Fetch data from Supabase
    try:
        response = (
            supabase.table("swarms_framework_schema")
            .select("*")
            .execute()
        )
        rows = response.data
        logger.info(f"Fetched {len(rows)} rows from Supabase.")
    except Exception as e:
        logger.error(f"Failed to fetch data from Supabase: {e}")
        return

    # Track published prompts to avoid duplicates
    published_prompts = set()

    for row in rows:
        # Extract agent_name and system_prompt
        data = row.get("data", {})
        agent_name = data.get("agent_name")
        system_prompt = data.get("system_prompt")

        # Skip if either field is missing, or if the prompt is a near-duplicate
        if not agent_name or not system_prompt:
            logger.warning(
                f"Skipping row due to missing agent_name or system_prompt: {row}"
            )
            continue
        if is_duplicate(system_prompt, published_prompts):
            logger.info(
                f"Skipping duplicate prompt for agent: {agent_name}"
            )
            continue

        # Create the data payload for the marketplace
        prompt_data = {
            "name": f"{agent_name} - System Prompt",
            "prompt": system_prompt,
            "description": f"System prompt for agent {agent_name}.",
            "useCases": extract_use_cases(system_prompt),
            "tags": "agent, system-prompt",
        }

        # Publish to the marketplace
        try:
            response = requests.post(
                SWARMS_API_URL,
                headers=headers,
                data=json.dumps(prompt_data),
            )
            if response.status_code == 200:
                logger.info(
                    f"Successfully published prompt for agent: {agent_name}"
                )
                published_prompts.add(system_prompt)
            else:
                logger.error(
                    f"Failed to publish prompt for agent: {agent_name}. Response: {response.text}"
                )
        except Exception as e:
            logger.error(
                f"Exception occurred while publishing prompt for agent: {agent_name}. Error: {e}"
            )


def is_duplicate(new_prompt, published_prompts):
    """Check if the prompt is a near-duplicate using string similarity."""
    for prompt in published_prompts:
        similarity = SequenceMatcher(None, new_prompt, prompt).ratio()
        if similarity > 0.9:  # Threshold for considering prompts duplicates
            return True
    return False


def extract_use_cases(prompt):
    """Extract use cases from the prompt by chunking it into fixed-size segments."""
    # This is a simple placeholder; a more advanced extraction method could be used
    chunks = [prompt[i : i + 50] for i in range(0, len(prompt), 50)]
    return [
        {"title": f"Use case {idx+1}", "description": chunk}
        for idx, chunk in enumerate(chunks)
    ]
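# Quick behavioral sketch of the helpers above (ratios computed by difflib;
# the values are easy to verify by hand):
#
#   >>> is_duplicate("hello world", {"hello world!"})  # ratio = 22/23 > 0.9
#   True
#   >>> len(extract_use_cases("a" * 120))  # 120 chars -> 3 chunks of <= 50
#   3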
# Main execution
fetch_and_publish_prompts()
@@ -0,0 +1,433 @@
import asyncio
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Optional, Set

import aiohttp
import matplotlib.pyplot as plt
import networkx as nx
import websockets
from loguru import logger

from swarms import Agent

TREND_AGENT_PROMPT = """You are a specialized blockchain trend analysis agent. Your role:
1. Analyze transaction patterns in Solana blockchain data
2. Identify volume trends, price movements, and temporal patterns
3. Focus on whale movements and their market impact
4. Format findings in clear, structured JSON
5. Include confidence scores for each insight
6. Flag unusual patterns or anomalies
7. Provide historical context for significant movements

Output format:
{
    "trends": [
        {"pattern": str, "confidence": float, "impact": str}
    ],
    "whale_activity": {...},
    "temporal_analysis": {...}
}"""

RISK_AGENT_PROMPT = """You are a blockchain risk assessment specialist. Your tasks:
1. Identify suspicious transaction patterns
2. Monitor for known exploit signatures
3. Assess wallet clustering and relationship patterns
4. Evaluate transaction velocity and size anomalies
5. Check for bridge-related risks
6. Monitor smart contract interactions
7. Flag potential wash trading

Output format:
{
    "risk_score": float,
    "flags": [...],
    "recommendations": [...]
}"""

SUMMARY_AGENT_PROMPT = """You are a blockchain data synthesis expert. Your responsibilities:
1. Combine insights from trend and risk analyses
2. Prioritize actionable intelligence
3. Highlight critical patterns
4. Generate executive summaries
5. Provide market context
6. Make predictions with confidence intervals
7. Suggest trading strategies based on data

Output format:
{
    "key_insights": [...],
    "market_impact": str,
    "recommendations": {...}
}"""


@dataclass
class Transaction:
    signature: str
    timestamp: datetime
    amount: float
    from_address: str
    to_address: str


class SolanaRPC:
    def __init__(
        self, endpoint="https://api.mainnet-beta.solana.com"
    ):
        self.endpoint = endpoint
        self.session = None

    async def get_signatures(self, address: str) -> List[Dict]:
        # Lazily create the shared HTTP session on first use
        if not self.session:
            self.session = aiohttp.ClientSession()

        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "getSignaturesForAddress",
            "params": [address, {"limit": 100}],
        }

        async with self.session.post(
            self.endpoint, json=payload
        ) as response:
            result = await response.json()
            return result.get("result", [])

    async def get_transaction(self, signature: str) -> Dict:
        # Same lazy-session guard as above (the original assumed
        # get_signatures had already been called)
        if not self.session:
            self.session = aiohttp.ClientSession()

        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "getTransaction",
            "params": [
                signature,
                {
                    "encoding": "json",
                    "maxSupportedTransactionVersion": 0,
                },
            ],
        }

        async with self.session.post(
            self.endpoint, json=payload
        ) as response:
            result = await response.json()
            return result.get("result", {})
|
||||
|
||||
|
||||
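# A minimal usage sketch (assumes network access to the public mainnet RPC;
# the address below is the System Program, used here only as an example):
#
#   async def demo():
#       rpc = SolanaRPC()
#       sigs = await rpc.get_signatures("11111111111111111111111111111111")
#       if sigs:
#           tx = await rpc.get_transaction(sigs[0]["signature"])
#           print(tx.get("blockTime"))
#
#   asyncio.run(demo())
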
class AlertSystem:
    def __init__(self, email: str, threshold: float = 1000.0):
        self.email = email
        self.threshold = threshold
        self.smtp_server = "smtp.gmail.com"
        self.smtp_port = 587

    async def check_and_alert(
        self, transaction: Transaction, risk_score: float
    ):
        if transaction.amount > self.threshold or risk_score > 0.8:
            await self.send_alert(transaction, risk_score)

    async def send_alert(
        self, transaction: Transaction, risk_score: float
    ):
        # Email delivery is stubbed out; for now the alert is only logged.
        # msg = MIMEText(
        #     f"High-risk transaction detected:\n"
        #     f"Amount: {transaction.amount} SOL\n"
        #     f"Risk Score: {risk_score}\n"
        #     f"Signature: {transaction.signature}"
        # )
        logger.info(
            f"Alert sent for transaction {transaction.signature}"
        )


class WalletClusterAnalyzer:
    def __init__(self):
        self.graph = nx.Graph()
        self.known_wallets: Set[str] = set()

    def update_graph(self, transaction: Transaction):
        self.graph.add_edge(
            transaction.from_address,
            transaction.to_address,
            weight=transaction.amount,
        )
        self.known_wallets.add(transaction.from_address)
        self.known_wallets.add(transaction.to_address)

    def identify_clusters(self) -> Dict:
        communities = nx.community.greedy_modularity_communities(
            self.graph
        )
        return {
            "clusters": [list(c) for c in communities],
            "central_wallets": [
                wallet
                for wallet in self.known_wallets
                if self.graph.degree[wallet] > 5
            ],
        }

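# A small illustrative sketch (toy addresses, not real wallets):
#
#   analyzer = WalletClusterAnalyzer()
#   analyzer.update_graph(Transaction("sig1", datetime.now(), 2.0, "walletA", "walletB"))
#   analyzer.update_graph(Transaction("sig2", datetime.now(), 5.0, "walletB", "walletC"))
#   clusters = analyzer.identify_clusters()
#   # "clusters" groups wallets that transact with each other; "central_wallets"
#   # only lists wallets with more than 5 distinct counterparties.
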
class TransactionVisualizer:
    def __init__(self):
        self.transaction_history = []

    def add_transaction(self, transaction: Transaction):
        self.transaction_history.append(asdict(transaction))

    def generate_volume_chart(self) -> str:
        volumes = [tx["amount"] for tx in self.transaction_history]
        plt.figure(figsize=(12, 6))
        plt.plot(volumes)
        plt.title("Transaction Volume Over Time")
        plt.savefig("volume_chart.png")
        return "volume_chart.png"

    def generate_network_graph(
        self, wallet_analyzer: WalletClusterAnalyzer
    ) -> str:
        plt.figure(figsize=(15, 15))
        pos = nx.spring_layout(wallet_analyzer.graph)
        nx.draw(
            wallet_analyzer.graph,
            pos,
            node_size=1000,
            node_color="lightblue",
            with_labels=True,
        )
        plt.savefig("network_graph.png")
        return "network_graph.png"


class SolanaMultiAgentAnalyzer:
    def __init__(
        self,
        min_amount: float = 50.0,
        websocket_url: str = "wss://api.mainnet-beta.solana.com",
        alert_email: str = None,
    ):
        self.rpc = SolanaRPC()
        self.websocket_url = websocket_url
        self.min_amount = min_amount
        self.transactions = []

        self.wallet_analyzer = WalletClusterAnalyzer()
        self.visualizer = TransactionVisualizer()
        self.alert_system = (
            AlertSystem(alert_email) if alert_email else None
        )

        self.trend_agent = Agent(
            agent_name="trend-analyzer",
            system_prompt=TREND_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.risk_agent = Agent(
            agent_name="risk-analyzer",
            system_prompt=RISK_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.summary_agent = Agent(
            agent_name="summary-agent",
            system_prompt=SUMMARY_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        logger.add(
            "solana_analysis.log", rotation="500 MB", level="INFO"
        )

    async def start_websocket_stream(self):
        async with websockets.connect(
            self.websocket_url
        ) as websocket:
            subscribe_message = {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "programSubscribe",
                "params": [
                    "11111111111111111111111111111111",
                    {"encoding": "json", "commitment": "confirmed"},
                ],
            }
            await websocket.send(json.dumps(subscribe_message))

            while True:
                try:
                    msg = await websocket.recv()
                    transaction = await self.parse_websocket_message(
                        msg
                    )
                    if (
                        transaction
                        and transaction.amount >= self.min_amount
                    ):
                        await self.process_transaction(transaction)
                except Exception as e:
                    logger.error(f"Websocket error: {e}")
                    await asyncio.sleep(5)

    async def parse_websocket_message(
        self, msg: str
    ) -> Optional[Transaction]:
        try:
            data = json.loads(msg)
            if "params" in data and "result" in data["params"]:
                tx_data = data["params"]["result"]
                return Transaction(
                    signature=tx_data["signature"],
                    timestamp=datetime.fromtimestamp(
                        tx_data["blockTime"]
                    ),
                    amount=float(
                        tx_data["meta"]["postBalances"][0]
                        - tx_data["meta"]["preBalances"][0]
                    )
                    / 1e9,
                    from_address=tx_data["transaction"]["message"][
                        "accountKeys"
                    ][0],
                    to_address=tx_data["transaction"]["message"][
                        "accountKeys"
                    ][1],
                )
        except Exception as e:
            logger.error(f"Error parsing websocket message: {e}")
            return None

    async def process_transaction(self, transaction: Transaction):
        self.wallet_analyzer.update_graph(transaction)
        self.visualizer.add_transaction(transaction)

        risk_analysis = await self.risk_agent.run(
            f"Analyze risk for transaction: {json.dumps(asdict(transaction))}"
        )

        if self.alert_system:
            # The agent returns text; parse it defensively and fall back
            # to a risk score of 0 if the output is not valid JSON
            try:
                risk_score = float(
                    json.loads(risk_analysis).get("risk_score", 0)
                )
            except (TypeError, ValueError):
                risk_score = 0.0
            await self.alert_system.check_and_alert(
                transaction, risk_score
            )

    async def fetch_transactions(self) -> List[Transaction]:
        try:
            signatures = await self.rpc.get_signatures(
                "11111111111111111111111111111111"
            )
            transactions = []

            for sig_info in signatures:
                tx_data = await self.rpc.get_transaction(
                    sig_info["signature"]
                )
                if not tx_data or "meta" not in tx_data:
                    continue

                pre_balances = tx_data["meta"]["preBalances"]
                post_balances = tx_data["meta"]["postBalances"]
                amount = abs(pre_balances[0] - post_balances[0]) / 1e9

                if amount >= self.min_amount:
                    tx = Transaction(
                        signature=sig_info["signature"],
                        timestamp=datetime.fromtimestamp(
                            tx_data["blockTime"]
                        ),
                        amount=amount,
                        from_address=tx_data["transaction"][
                            "message"
                        ]["accountKeys"][0],
                        to_address=tx_data["transaction"]["message"][
                            "accountKeys"
                        ][1],
                    )
                    transactions.append(tx)

            return transactions
        except Exception as e:
            logger.error(f"Error fetching transactions: {e}")
            return []

    async def analyze_transactions(
        self, transactions: List[Transaction]
    ) -> Dict:
        tx_data = [asdict(tx) for tx in transactions]
        cluster_data = self.wallet_analyzer.identify_clusters()

        trend_analysis = await self.trend_agent.run(
            f"Analyze trends in: {json.dumps(tx_data)}"
        )
        print(trend_analysis)

        risk_analysis = await self.risk_agent.run(
            f"Analyze risks in: {json.dumps({'transactions': tx_data, 'clusters': cluster_data})}"
        )
        print(risk_analysis)

        summary = await self.summary_agent.run(
            f"Synthesize insights from: {trend_analysis}, {risk_analysis}"
        )
        print(summary)

        volume_chart = self.visualizer.generate_volume_chart()
        network_graph = self.visualizer.generate_network_graph(
            self.wallet_analyzer
        )

        return {
            "transactions": tx_data,
            "trend_analysis": trend_analysis,
            "risk_analysis": risk_analysis,
            "cluster_analysis": cluster_data,
            "summary": summary,
            "visualizations": {
                "volume_chart": volume_chart,
                "network_graph": network_graph,
            },
        }

    async def run_continuous_analysis(self):
        logger.info("Starting continuous analysis")
        asyncio.create_task(self.start_websocket_stream())

        while True:
            try:
                transactions = await self.fetch_transactions()
                if transactions:
                    analysis = await self.analyze_transactions(
                        transactions
                    )
                    timestamp = datetime.now().strftime(
                        "%Y%m%d_%H%M%S"
                    )
                    with open(f"analysis_{timestamp}.json", "w") as f:
                        json.dump(analysis, f, indent=2, default=str)
                    logger.info(
                        f"Analysis completed: analysis_{timestamp}.json"
                    )
                await asyncio.sleep(60)
            except Exception as e:
                logger.error(f"Error in analysis loop: {e}")
                await asyncio.sleep(60)


if __name__ == "__main__":
    logger.info("Starting Solana analyzer...")
    analyzer = SolanaMultiAgentAnalyzer(alert_email="your@email.com")
    try:
        asyncio.run(analyzer.run_continuous_analysis())
    except Exception as e:
        logger.error(f"Critical error: {e}")
@ -1,29 +0,0 @@
from typing import Dict, Optional
from datetime import datetime
from pydantic import BaseModel, Field


class Message(BaseModel):
    """
    Represents a message with a timestamp and optional metadata.

    Usage
    --------------
    mes = Message(
        sender = "Kye",
        content = "message"
    )

    print(mes)
    """

    timestamp: datetime = Field(default_factory=datetime.now)
    sender: str
    content: str
    metadata: Optional[Dict[str, str]] = {}

    def __repr__(self) -> str:
        """
        Return a human-readable "timestamp - sender: content" string.
        """
        return f"{self.timestamp} - {self.sender}: {self.content}"
@ -0,0 +1,15 @@
from typing import Literal

# Literal of output types
OutputType = Literal[
    "all",
    "final",
    "list",
    "dict",
    ".json",
    ".md",
    ".txt",
    ".yaml",
    ".toml",
]
@ -1,511 +0,0 @@
"""
Todo
- [ ] Test the new api feature
- [ ] Add the agent schema for every agent -- following OpenAI assistants schema
- [ ] then add the swarm schema for the swarm url: /v1/swarms/{swarm_name}/agents/{agent_id}
- [ ] Add the agent schema for the agent url: /v1/swarms/{swarm_name}/agents/{agent_id}
"""

import asyncio
import multiprocessing
import queue
import threading
from typing import List, Optional

import tenacity

# from fastapi import FastAPI
from pydantic import BaseModel

from swarms.structs.agent import Agent
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.loguru_logger import initialize_logger

logger = initialize_logger("swarm-network")


# Pydantic models
class TaskRequest(BaseModel):
    task: str


class TaskResponse(BaseModel):
    result: str


class AgentInfo(BaseModel):
    agent_name: str
    agent_description: str


class SwarmInfo(BaseModel):
    swarm_name: str
    swarm_description: str
    agents: List[AgentInfo]


# Helper function to get the number of workers
def get_number_of_workers():
    return multiprocessing.cpu_count()


# [TODO] Add the agent schema for every agent -- following OpenAI assistants schema
class SwarmNetwork(BaseSwarm):
    """
    SwarmNetwork class

    The SwarmNetwork class is responsible for managing the agents pool
    and the task queue. It also monitors the health of the agents and
    scales the pool up or down based on the number of pending tasks
    and the current load of the agents.

    For example, if the number of pending tasks is greater than the
    number of agents in the pool, the SwarmNetwork will scale up the
    pool by adding new agents. If the number of pending tasks is less
    than the number of agents in the pool, the SwarmNetwork will scale
    down the pool by removing agents.

    The SwarmNetwork class also provides a simple API for interacting
    with the agents pool. The API is implemented using the FastAPI
    framework and is enabled by default. The API can be disabled by
    setting the `api_enabled` parameter to False.

    Features:
        - Agent pool management
        - Task queue management
        - Agent health monitoring
        - Agent pool scaling
        - Simple API for interacting with the agent pool
        - Simple API for interacting with the task queue
        - Simple API for interacting with the agent health monitor
        - Simple API for interacting with the agent pool scaler
        - Create APIs for each agent in the pool (optional)
        - Run each agent on its own thread
        - Run each agent on its own process
        - Run each agent on its own container
        - Run each agent on its own machine
        - Run each agent on its own cluster

    Attributes:
        task_queue (queue.Queue): A queue for storing tasks.
        idle_threshold (float): The idle threshold for the agents.
        busy_threshold (float): The busy threshold for the agents.
        agents (List[Agent]): A list of agents in the pool.
        api_enabled (bool): A flag to enable/disable the API.
        logging_enabled (bool): A flag to enable/disable logging.

    Example:
        >>> from swarms.structs.agent import Agent
        >>> from swarms.structs.swarm_net import SwarmNetwork
        >>> agent = Agent()
        >>> swarm = SwarmNetwork(agents=[agent])
        >>> swarm.add_task("task")
        >>> swarm.run()

    """

    def __init__(
        self,
        name: str = None,
        description: str = None,
        agents: List[Agent] = None,
        idle_threshold: float = 0.2,
        busy_threshold: float = 0.7,
        api_enabled: Optional[bool] = False,
        logging_enabled: Optional[bool] = False,
        api_on: Optional[bool] = False,
        host: str = "0.0.0.0",
        port: int = 8000,
        swarm_callable: Optional[callable] = None,
        *args,
        **kwargs,
    ):
        super().__init__(agents=agents, *args, **kwargs)
        self.name = name
        self.description = description
        self.agents = agents
        self.task_queue = queue.Queue()
        self.idle_threshold = idle_threshold
        self.busy_threshold = busy_threshold
        self.lock = threading.Lock()
        self.api_enabled = api_enabled
        self.logging_enabled = logging_enabled
        self.host = host
        self.port = port
        self.swarm_callable = swarm_callable

        # Ensure that the agents list is not empty
        if not agents:
            raise ValueError("The agents list cannot be empty")

        # Create a dictionary of agents for easy access
        self.agent_dict = {agent.id: agent for agent in agents}

        # # Create the FastAPI instance
        # if api_on is True:
        #     logger.info("Creating FastAPI instance")
        #     self.app = FastAPI(debug=True, *args, **kwargs)

        #     self.app.add_middleware(
        #         CORSMiddleware,
        #         allow_origins=["*"],
        #         allow_credentials=True,
        #         allow_methods=["*"],
        #         allow_headers=["*"],
        #     )

        #     logger.info("Routes set for creation")
        #     self._create_routes()

    def add_task(self, task):
        """Add a task to the task queue.

        Args:
            task: The task to enqueue.

        Example:
            >>> from swarms.structs.agent import Agent
            >>> from swarms.structs.swarm_net import SwarmNetwork
            >>> agent = Agent()
            >>> swarm = SwarmNetwork(agents=[agent])
            >>> swarm.add_task("task")
        """
        self.logger.info(f"Adding task {task} to queue")
        try:
            self.task_queue.put(task)
            self.logger.info(f"Task {task} added to queue")
        except Exception as error:
            print(
                f"Error adding task to queue: {error} try again with"
                " a new task"
            )
            raise error

    async def async_add_task(self, task):
        """Add a task to the task queue asynchronously.

        Args:
            task: The task to enqueue.

        Example:
            >>> from swarms.structs.agent import Agent
            >>> from swarms.structs.swarm_net import SwarmNetwork
            >>> agent = Agent()
            >>> swarm = SwarmNetwork(agents=[agent])
            >>> swarm.add_task("task")
        """
        self.logger.info(
            f"Adding task {task} to queue asynchronously"
        )
        try:
            # Add task to queue asynchronously with asyncio
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, self.task_queue.put, task
            )
            self.logger.info(f"Task {task} added to queue")
        except Exception as error:
            print(
                f"Error adding task to queue: {error} try again with"
                " a new task"
            )
            raise error

    # def _create_routes(self) -> None:
    #     """
    #     Creates the routes for the API.
    #     """
    #     # Extensive logging
    #     logger.info("Creating routes for the API")

    #     # Routes available
    #     logger.info(
    #         "Routes available: /v1/swarms, /v1/health, /v1/swarms/{swarm_name}/agents/{agent_id}, /v1/swarms/{swarm_name}/run"
    #     )

    #     @self.app.get("/v1/swarms", response_model=SwarmInfo)
    #     async def get_swarms() -> SwarmInfo:
    #         try:
    #             logger.info("Getting swarm information")
    #             return SwarmInfo(
    #                 swarm_name=self.swarm_name,
    #                 swarm_description=self.swarm_description,
    #                 agents=[
    #                     AgentInfo(
    #                         agent_name=agent.agent_name,
    #                         agent_description=agent.agent_description,
    #                     )
    #                     for agent in self.agents
    #                 ],
    #             )
    #         except Exception as e:
    #             logger.error(f"Error getting swarm information: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.get("/v1/health")
    #     async def get_health() -> Dict[str, str]:
    #         try:
    #             logger.info("Checking health status")
    #             return {"status": "healthy"}
    #         except Exception as e:
    #             logger.error(f"Error checking health status: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.get(f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}")
    #     async def get_agent_info(agent_id: str) -> AgentInfo:
    #         try:
    #             logger.info(f"Getting information for agent {agent_id}")
    #             agent = self.agent_dict.get(agent_id)
    #             if not agent:
    #                 raise HTTPException(
    #                     status_code=404, detail="Agent not found"
    #                 )
    #             return AgentInfo(
    #                 agent_name=agent.agent_name,
    #                 agent_description=agent.agent_description,
    #             )
    #         except Exception as e:
    #             logger.error(f"Error getting agent information: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    #     @self.app.post(
    #         f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}/run",
    #         response_model=TaskResponse,
    #     )
    #     async def run_agent_task(
    #         task_request: TaskRequest,
    #     ) -> TaskResponse:
    #         try:
    #             logger.info("Running agent task")
    #             # Assuming only one agent in the swarm for this example
    #             agent = self.agents[0]
    #             logger.info(f"Running agent task: {task_request.task}")
    #             result = agent.run(task_request.task)
    #             return TaskResponse(result=result)
    #         except Exception as e:
    #             logger.error(f"Error running agent task: {str(e)}")
    #             raise HTTPException(
    #                 status_code=500, detail="Internal Server Error"
    #             )

    # def get_app(self) -> FastAPI:
    #     """
    #     Returns the FastAPI instance.

    #     Returns:
    #         FastAPI: The FastAPI instance.
    #     """
    #     return self.app

    def run_single_agent(
        self, agent_id, task: Optional[str], *args, **kwargs
    ):
        """Run the task on the agent with the given id.

        Args:
            agent_id: The id of the agent to run the task on.
            task (str, optional): The task to run. Defaults to None.

        Raises:
            ValueError: If no agent with the given id is found.

        Returns:
            The output of the agent's run method.
        """
        self.logger.info(f"Running task {task} on agent {agent_id}")
        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    out = agent.run(task, *args, **kwargs)
                    return out
            raise ValueError(f"No agent found with ID {agent_id}")
        except Exception as error:
            self.logger.error(f"Error running task on agent: {error}")
            raise error

    def run_many_agents(
        self, task: Optional[str] = None, *args, **kwargs
    ) -> List:
        """Run the task on all agents.

        Args:
            task (str, optional): The task to run. Defaults to None.

        Returns:
            List: The outputs of all agents.
        """
        self.logger.info(f"Running task {task} on all agents")
        try:
            return [
                agent.run(task, *args, **kwargs)
                for agent in self.agents
            ]
        except Exception as error:
            logger.error(f"Error running task on agents: {error}")
            raise error

    def list_agents(self):
        """List all agents."""
        self.logger.info("[Listing all active agents]")

        try:
            # Assuming self.agents is a list of agent objects
            for agent in self.agents:
                self.logger.info(
                    f"[Agent] [ID: {agent.id}] [Name:"
                    f" {agent.agent_name}] [Description:"
                    f" {agent.agent_description}] [Status: Running]"
                )
        except Exception as error:
            self.logger.error(f"Error listing agents: {error}")
            raise

    def get_agent(self, agent_id):
        """Get an agent by id.

        Args:
            agent_id: The id of the agent to fetch.

        Returns:
            Agent: The matching agent.
        """
        self.logger.info(f"Getting agent {agent_id}")

        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    return agent
            raise ValueError(f"No agent found with ID {agent_id}")
        except Exception as error:
            self.logger.error(f"Error getting agent: {error}")
            raise error

    def add_agent(self, agent: Agent):
        """Add an agent to the agent pool.

        Args:
            agent (Agent): The agent to add.
        """
        self.logger.info(f"Adding agent {agent} to pool")
        try:
            self.agents.append(agent)
        except Exception as error:
            print(f"Error adding agent to pool: {error}")
            raise error

    def remove_agent(self, agent_id):
        """Remove an agent from the agent pool.

        Args:
            agent_id: The id of the agent to remove.
        """
        self.logger.info(f"Removing agent {agent_id} from pool")
        try:
            for agent in self.agents:
                if agent.id == agent_id:
                    self.agents.remove(agent)
                    return
            raise ValueError(f"No agent found with ID {agent_id}")
        except Exception as error:
            print(f"Error removing agent from pool: {error}")
            raise error

    async def async_remove_agent(self, agent_id):
        """Remove an agent from the agent pool asynchronously.

        Args:
            agent_id: The id of the agent to remove.
        """
        self.logger.info(f"Removing agent {agent_id} from pool")
        try:
            # Remove agent from pool asynchronously with asyncio
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, self.remove_agent, agent_id
            )
        except Exception as error:
            print(f"Error removing agent from pool: {error}")
            raise error

    def scale_up(self, num_agents: int = 1):
        """Scale up the agent pool.

        Args:
            num_agents (int, optional): The number of agents to add. Defaults to 1.
        """
        self.logger.info(f"Scaling up agent pool by {num_agents}")
        try:
            for _ in range(num_agents):
                self.agents.append(Agent())
        except Exception as error:
            print(f"Error scaling up agent pool: {error}")
            raise error

    def scale_down(self, num_agents: int = 1):
        """Scale down the agent pool.

        Args:
            num_agents (int, optional): The number of agents to remove. Defaults to 1.
        """
        for _ in range(num_agents):
            self.agents.pop()

    @tenacity.retry(
        wait=tenacity.wait_fixed(1),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(Exception),
    )
    def run(self, *args, **kwargs):
        """Run the swarm network."""
        # NOTE: depends on the commented-out FastAPI wiring above
        # (get_app) being re-enabled
        app = self.get_app()

        try:
            import uvicorn

            logger.info(
                f"Running the swarm network with {len(self.agents)} agents on {self.host}:{self.port}"
            )
            uvicorn.run(
                app,
                host=self.host,
                port=self.port,
                # workers=get_number_of_workers(),
                *args,
                **kwargs,
            )

            return app
        except Exception as error:
            logger.error(f"Error running the swarm network: {error}")
            raise error


# # # Example usage
# if __name__ == "__main__":

#     agent1 = Agent(
#         agent_name="Covid-19-Chat",
#         agent_description="This agent provides information about COVID-19 symptoms.",
#         llm=OpenAIChat(),
#         max_loops="auto",
#         autosave=True,
#         verbose=True,
#         stopping_condition="finish",
#     )

#     agents = [agent1]  # Add more agents as needed
#     swarm_name = "HealthSwarm"
#     swarm_description = (
#         "A swarm of agents providing health-related information."
#     )

#     agent_api = SwarmNetwork(swarm_name, swarm_description, agents)
#     agent_api.run()
@ -0,0 +1,146 @@
"""
Package installation utility that checks for package existence and installs if needed.
Supports both pip and conda package managers.
"""

import importlib.util
import subprocess
import sys
from typing import Literal, Optional, Union
from swarms.utils.loguru_logger import initialize_logger
import pkg_resources


logger = initialize_logger("autocheckpackages")


def check_and_install_package(
    package_name: str,
    package_manager: Literal["pip", "conda"] = "pip",
    version: Optional[str] = None,
    upgrade: bool = False,
) -> bool:
    """
    Check if a package is installed and install it if not found.

    Args:
        package_name: Name of the package to check/install
        package_manager: Package manager to use ('pip' or 'conda')
        version: Specific version to install (optional)
        upgrade: Whether to upgrade the package if it exists

    Returns:
        bool: True if package is available after check/install, False if installation failed

    Raises:
        ValueError: If invalid package manager is specified
    """
    try:
        # Check if package exists
        if package_manager == "pip":
            try:
                pkg_resources.get_distribution(package_name)
                if not upgrade:
                    logger.info(
                        f"Package {package_name} is already installed"
                    )
                    return True
            except pkg_resources.DistributionNotFound:
                pass

            # Construct installation command
            cmd = [sys.executable, "-m", "pip", "install"]
            if upgrade:
                cmd.append("--upgrade")

            if version:
                cmd.append(f"{package_name}=={version}")
            else:
                cmd.append(package_name)

        elif package_manager == "conda":
            # Check if conda is available
            try:
                subprocess.run(
                    ["conda", "--version"],
                    check=True,
                    capture_output=True,
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                logger.error(
                    "Conda is not available. Please install conda first."
                )
                return False

            # Construct conda command
            cmd = ["conda", "install", "-y"]
            if version:
                cmd.append(f"{package_name}={version}")
            else:
                cmd.append(package_name)
        else:
            raise ValueError(
                f"Invalid package manager: {package_manager}"
            )

        # Run installation
        logger.info(f"Installing {package_name}...")
        subprocess.run(
            cmd, check=True, capture_output=True, text=True
        )

        # Verify installation
        try:
            importlib.import_module(package_name)
            logger.info(f"Successfully installed {package_name}")
            return True
        except ImportError:
            logger.error(
                f"Package {package_name} was installed but cannot be imported"
            )
            return False

    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install {package_name}: {e.stderr}")
        return False
    except Exception as e:
        logger.error(
            f"Unexpected error while installing {package_name}: {str(e)}"
        )
        return False


def auto_check_and_download_package(
    packages: Union[str, list[str]],
    package_manager: Literal["pip", "conda"] = "pip",
    upgrade: bool = False,
) -> bool:
    """
    Ensure multiple packages are installed.

    Args:
        packages: Single package name or list of package names
        package_manager: Package manager to use ('pip' or 'conda')
        upgrade: Whether to upgrade existing packages

    Returns:
        bool: True if all packages are available, False if any installation failed
    """
    if isinstance(packages, str):
        packages = [packages]

    success = True
    for package in packages:
        if ":" in package:
            name, version = package.split(":")
            if not check_and_install_package(
                name, package_manager, version, upgrade
            ):
                success = False
        else:
            if not check_and_install_package(
                package, package_manager, upgrade=upgrade
            ):
                success = False

    return success

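# A minimal usage sketch; "requests" and the "name:version" pin syntax are
# illustrative examples, not requirements of this module:
#
#   if auto_check_and_download_package(["requests", "loguru:0.7.2"]):
#       import requests  # safe to import once the check succeeds
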
@ -0,0 +1,263 @@
"""
Lazy Package Loader

This module provides utilities for lazy loading Python packages to improve startup time
and reduce memory usage by only importing packages when they are actually used.

Features:
- Type-safe lazy loading of packages
- Support for nested module imports
- Auto-completion support in IDEs
- Thread-safe implementation
- Comprehensive test coverage
"""

from types import ModuleType
from typing import (
    Optional,
    Dict,
    Any,
    Callable,
    Type,
    TypeVar,
    Union,
    cast,
)
import importlib
import functools
import threading
from importlib.util import find_spec
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)


T = TypeVar("T")
C = TypeVar("C")


class ImportError(Exception):
    """Raised when a lazy import fails. Note: this shadows the builtin ImportError within this module."""

    pass


class LazyLoader:
    """
    A thread-safe lazy loader for Python packages that only imports them when accessed.

    Attributes:
        _module_name (str): The name of the module to be lazily loaded
        _module (Optional[ModuleType]): The cached module instance once loaded
        _lock (threading.Lock): Thread lock for safe concurrent access

    Examples:
        >>> np = LazyLoader('numpy')
        >>> # numpy is not imported yet
        >>> result = np.array([1, 2, 3])
        >>> # numpy is imported only when first used
    """

    def __init__(self, module_name: str) -> None:
        """
        Initialize the lazy loader with a module name.

        Args:
            module_name: The fully qualified name of the module to lazily load

        Raises:
            ImportError: If the module cannot be found in sys.path
        """
        self._module_name = module_name
        self._module: Optional[ModuleType] = None
        self._lock = threading.Lock()

        auto_check_and_download_package(
            module_name, package_manager="pip"
        )

        # Verify module exists without importing it
        if find_spec(module_name) is None:
            raise ImportError(
                f"Module '{module_name}' not found in sys.path"
            )

    def _load_module(self) -> ModuleType:
        """
        Thread-safe module loading.

        Returns:
            ModuleType: The loaded module

        Raises:
            ImportError: If module import fails
        """
        if self._module is None:
            with self._lock:
                # Double-check pattern
                if self._module is None:
                    try:
                        self._module = importlib.import_module(
                            self._module_name
                        )
                    except Exception as e:
                        raise ImportError(
                            f"Failed to import '{self._module_name}': {str(e)}"
                        )
        return cast(ModuleType, self._module)

    def __getattr__(self, name: str) -> Any:
        """
        Intercepts attribute access to load the module if needed.

        Args:
            name: The attribute name being accessed

        Returns:
            Any: The requested attribute from the loaded module

        Raises:
            AttributeError: If the attribute doesn't exist in the module
        """
        module = self._load_module()
        try:
            return getattr(module, name)
        except AttributeError:
            raise AttributeError(
                f"Module '{self._module_name}' has no attribute '{name}'"
            )

    def __dir__(self) -> list[str]:
        """
        Returns list of attributes for autocomplete support.

        Returns:
            List[str]: Available attributes in the module
        """
        return dir(self._load_module())

    def is_loaded(self) -> bool:
        """
        Check if the module has been loaded.

        Returns:
            bool: True if module is loaded, False otherwise
        """
        return self._module is not None


class LazyLoaderMetaclass(type):
    """Metaclass to handle lazy loading behavior"""

    def __call__(cls, *args, **kwargs):
        # Lazily created and regular classes are instantiated the same
        # way for now; this hook is kept for future customization
        return super().__call__(*args, **kwargs)


class LazyClassLoader:
    """
    A descriptor that creates the actual class only when accessed,
    with proper inheritance support.
    """

    def __init__(
        self, class_name: str, bases: tuple, namespace: Dict[str, Any]
    ):
        self.class_name = class_name
        self.bases = bases
        self.namespace = namespace
        self._real_class: Optional[Type] = None
        self._lock = threading.Lock()

    def _create_class(self) -> Type:
        """Creates the actual class if it hasn't been created yet."""
        if self._real_class is None:
            with self._lock:
                if self._real_class is None:
                    # Update namespace to include metaclass
                    namespace = dict(self.namespace)
                    namespace["__metaclass__"] = LazyLoaderMetaclass

                    # Create the class with metaclass
                    new_class = LazyLoaderMetaclass(
                        self.class_name, self.bases, namespace
                    )

                    # Store reference to this loader
                    new_class._lazy_loader = self
                    self._real_class = new_class

        return cast(Type, self._real_class)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Creates an instance of the lazy loaded class."""
        real_class = self._create_class()
        # Use the metaclass __call__ method
        return real_class(*args, **kwargs)

    def __instancecheck__(self, instance: Any) -> bool:
        """Support for isinstance() checks"""
        real_class = self._create_class()
        return isinstance(instance, real_class)

    def __subclasscheck__(self, subclass: Type) -> bool:
        """Support for issubclass() checks"""
        real_class = self._create_class()
        return issubclass(subclass, real_class)


def lazy_import(*names: str) -> Dict[str, LazyLoader]:
    """
    Create multiple lazy loaders at once.

    Args:
        *names: Module names to create lazy loaders for

    Returns:
        Dict[str, LazyLoader]: Dictionary mapping module names to their lazy loaders

    Examples:
        >>> modules = lazy_import('numpy', 'pandas', 'matplotlib.pyplot')
        >>> np = modules['numpy']
        >>> pd = modules['pandas']
        >>> plt = modules['matplotlib.pyplot']
    """
    return {name.split(".")[-1]: LazyLoader(name) for name in names}


def lazy_import_decorator(
    target: Union[Callable[..., T], Type[C]]
) -> Union[Callable[..., T], Type[C], LazyClassLoader]:
    """
    Enhanced decorator that supports both lazy imports and lazy class loading.
    """
    if isinstance(target, type):
        # Store the original class details
        namespace = {
            name: value
            for name, value in target.__dict__.items()
            if not name.startswith("__")
            or name in ("__init__", "__new__")
        }

        # Create lazy loader
        loader = LazyClassLoader(
            target.__name__, target.__bases__, namespace
        )

        # Preserve class metadata
        loader.__module__ = target.__module__
        loader.__doc__ = target.__doc__

        # Add reference to original class
        loader._original_class = target

        return loader
    else:
        # Handle function decoration
        @functools.wraps(target)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            return target(*args, **kwargs)

        return wrapper

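# A minimal sketch of the decorator on a class (hypothetical class, for
# illustration only); the real class is not constructed until first use:
#
#   @lazy_import_decorator
#   class HeavyModel:
#       def __init__(self):
#           print("expensive setup runs here")
#
#   model = HeavyModel()  # the actual class is created on this first call
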
@ -1,73 +0,0 @@
import os
import tempfile

import pygame
from loguru import logger
from openai import OpenAI


class OpenAITTS:
    """
    A class to interact with the OpenAI API and play the generated audio with improved streaming capabilities.
    """

    def __init__(self, *args, **kwargs):
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs
        )
        pygame.init()

    def run(
        self, task: str, play_sound: bool = True, *args, **kwargs
    ):
        """
        Run a task with the OpenAI API and optionally play the generated audio with improved streaming.

        Args:
            task (str): The task to be executed.
            play_sound (bool): If True, play the generated audio.

        Returns:
            None
        """
        try:
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="nova",
                input=task,
                *args,
                **kwargs,
            )
            logger.info("Task completed successfully.")

            if play_sound:
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=".mp3"
                ) as tmp_file:
                    # The speech endpoint returns the audio bytes
                    # directly rather than a URL to download
                    tmp_file.write(response.content)
                pygame.mixer.music.load(tmp_file.name)
                pygame.mixer.music.play()
                while pygame.mixer.music.get_busy():
                    pygame.time.Clock().tick(10)
        except Exception as e:
            logger.error(f"Error during task execution: {str(e)}")


# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY"))
# client.run("Hello world! This is a streaming test.", play_sound=True)


def text_to_speech(
    task: str, play_sound: bool = True, *args, **kwargs
):
    out = OpenAITTS().run(
        task, play_sound=play_sound, *args, **kwargs
    )
    return out


# print(text_to_speech(task="hello"))
@ -0,0 +1,42 @@
from swarms.structs.tree_swarm import ForestSwarm, Tree, TreeAgent


agents_tree1 = [
    TreeAgent(
        system_prompt="Stock Analysis Agent",
        agent_name="Stock Analysis Agent",
    ),
    TreeAgent(
        system_prompt="Financial Planning Agent",
        agent_name="Financial Planning Agent",
    ),
    TreeAgent(
        agent_name="Retirement Strategy Agent",
        system_prompt="Retirement Strategy Agent",
    ),
]

agents_tree2 = [
    TreeAgent(
        system_prompt="Tax Filing Agent",
        agent_name="Tax Filing Agent",
    ),
    TreeAgent(
        system_prompt="Investment Strategy Agent",
        agent_name="Investment Strategy Agent",
    ),
    TreeAgent(
        system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
    ),
]

# Create trees
tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)

# Create the ForestSwarm
multi_agent_structure = ForestSwarm(trees=[tree1, tree2])

# Run a task
task = "Our company is incorporated in Delaware, how do we do our taxes for free?"
multi_agent_structure.run(task)
@ -0,0 +1,206 @@
from swarms import Agent
from loguru import logger
import random
import re

# Configure loguru
logger.add("zkp_log.log", rotation="500 KB", retention="10 days", level="INFO")


class ProverAgent:
    """
    Prover Agent for Zero Knowledge Proof.

    Responsibilities:
    - Generate commitments based on a secret.
    - Respond to challenges from the Verifier.

    Attributes:
        agent (Agent): Swarms agent instance.
        p (int): The prime modulus.
        g (int): The generator.
        x (int): The Prover's secret.
    """

    def __init__(self, p: int, g: int, secret: int):
        self.p = p
        self.g = g
        self.x = secret  # Prover's secret
        self.agent = Agent(
            agent_name="ProverAgent",
            model_name="gpt-4o-mini",
            max_loops=1,
            interactive=False,
            streaming_on=True,
            system_prompt=(
                "You are the Prover in a Zero Knowledge Proof (ZKP) system. "
                "Your responsibilities are to generate commitments based on a secret value and "
                "respond to challenges from the Verifier without revealing the secret. "
                "Follow mathematical rules of modular arithmetic when performing computations."
            ),
        )
        logger.info("Initialized ProverAgent with p={}, g={}, secret={}", p, g, secret)

    def generate_commitment(self) -> tuple[int, int]:
        """
        Generates a random commitment for the proof.

        Returns:
            tuple[int, int]: The random value (r) and the commitment (t).
        """
        r = random.randint(1, self.p - 2)
        task = (
            f"Compute the commitment t = g^r % p for g={self.g}, r={r}, p={self.p}. "
            "Return only the numerical value of t as an integer."
        )
        t = self.agent.run(task=task)
        t_value = self._extract_integer(t, "commitment")
        logger.info("Prover generated commitment: r={}, t={}", r, t_value)
        return r, t_value

    def _extract_integer(self, response: str, label: str) -> int:
        """
        Extracts an integer from the LLM response.

        Args:
            response (str): The response from the agent.
            label (str): A label for logging purposes.

        Returns:
            int: The extracted integer value.
        """
        try:
            # Use regex to find the first integer in the response
            match = re.search(r"\b\d+\b", response)
            if match:
                value = int(match.group(0))
                return value
            else:
                raise ValueError(f"No integer found in {label} response: {response}")
        except Exception as e:
            logger.error("Failed to extract integer from {} response: {}", label, response)
            raise ValueError(f"Invalid {label} response: {response}") from e

    def respond_to_challenge(self, r: int, c: int) -> int:
        """
        Computes the response to a challenge.

        Args:
            r (int): The random value used in the commitment.
            c (int): The challenge issued by the Verifier.

        Returns:
            int: The response (z).
        """
        task = (
            f"Compute the response z = (r + c * x) % (p-1) for r={r}, c={c}, x={self.x}, p={self.p}. "
            "Return only the numerical value of z as an integer."
        )
        z = self.agent.run(task=task)
        # Parse the agent's text output the same way as the commitment
        z_value = self._extract_integer(z, "challenge response")
        logger.info("Prover responded to challenge: z={}", z_value)
        return z_value


class VerifierAgent:
    """
    Verifier Agent for Zero Knowledge Proof.

    Responsibilities:
    - Issue challenges to the Prover.
    - Verify the Prover's response.

    Attributes:
        agent (Agent): Swarms agent instance.
        p (int): The prime modulus.
        g (int): The generator.
        y (int): The public value from the Prover.
    """

    def __init__(self, p: int, g: int, y: int):
        self.p = p
        self.g = g
        self.y = y  # Public value
        self.agent = Agent(
            agent_name="VerifierAgent",
            model_name="gpt-4o-mini",
            max_loops=1,
            interactive=False,
            streaming_on=True,
            system_prompt=(
                "You are the Verifier in a Zero Knowledge Proof (ZKP) system. "
                "Your responsibilities are to issue random challenges and verify the Prover's response. "
                "Use modular arithmetic to check if the proof satisfies g^z % p == (t * y^c) % p."
            ),
        )
        logger.info("Initialized VerifierAgent with p={}, g={}, y={}", p, g, y)

    def issue_challenge(self) -> int:
        """
        Issues a random challenge to the Prover.

        Returns:
            int: The challenge value (c).
        """
        c = random.randint(1, 10)
        logger.info("Verifier issued challenge: c={}", c)
        return c

    def verify_proof(self, t: int, z: int, c: int) -> bool:
        """
        Verifies the Prover's response.

        Args:
            t (int): The commitment from the Prover.
            z (int): The response from the Prover.
            c (int): The challenge issued to the Prover.

        Returns:
            bool: True if the proof is valid, False otherwise.
        """
        task = (
            f"Verify if g^z % p == (t * y^c) % p for g={self.g}, z={z}, p={self.p}, t={t}, y={self.y}, c={c}. "
            "Answer only True or False."
        )
        verification_result = self.agent.run(task=task)
        is_valid = verification_result.strip().lower() == "true"
        logger.info("Verifier checked proof: t={}, z={}, c={}, valid={}", t, z, c, is_valid)
        return is_valid


class CoordinatorAgent:
    """
    Coordinator for orchestrating the Zero Knowledge Proof protocol.

    Responsibilities:
    - Initialize parameters.
    - Facilitate interaction between Prover and Verifier agents.
    """

    def __init__(self, p: int, g: int, secret: int):
        self.p = p
        self.g = g
        self.prover = ProverAgent(p, g, secret)
        y = pow(g, secret, p)  # Public value
        self.verifier = VerifierAgent(p, g, y)
        logger.info("Coordinator initialized with p={}, g={}, secret={}", p, g, secret)

    def orchestrate(self) -> bool:
        """
        Orchestrates the Zero Knowledge Proof protocol.

        Returns:
            bool: True if the proof is valid, False otherwise.
        """
        logger.info("Starting ZKP protocol orchestration.")
        r, t = self.prover.generate_commitment()
        c = self.verifier.issue_challenge()
        z = self.prover.respond_to_challenge(r, c)
        is_valid = self.verifier.verify_proof(t, z, c)
        logger.info("ZKP protocol completed. Valid proof: {}", is_valid)
        return is_valid

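# A worked round of the protocol with the example parameters below, assuming
# the agents compute the modular arithmetic correctly (fixed r and c shown
# here only so the numbers can be checked by hand):
#
#   p = 23, g = 5, secret x = 7   =>  y = 5^7 % 23 = 17
#   commitment with r = 3:            t = 5^3 % 23 = 10
#   challenge c = 2:                  z = (3 + 2 * 7) % 22 = 17
#   check: 5^17 % 23 = 15   and   (10 * 17^2) % 23 = 15  ->  valid
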
if __name__ == "__main__":
    # Example parameters
    p = 23  # Prime number
    g = 5  # Generator
    secret = 7  # Prover's secret

    # Initialize the Coordinator and run the protocol
    coordinator = CoordinatorAgent(p, g, secret)
    result = coordinator.orchestrate()
    print(f"Zero Knowledge Proof Verification Result: {'Valid' if result else 'Invalid'}")