diff --git a/api/agent_api.py b/api/agent_api.py index d1968d9d..922c4572 100644 --- a/api/agent_api.py +++ b/api/agent_api.py @@ -60,7 +60,7 @@ class AgentConfig(BaseModel): ..., description="System prompt for the agent" ) model_name: str = Field( - default="gpt-4", description="Model name to use" + default="gpt-4o-mini", description="Model name to use" ) temperature: float = Field( default=0.1, @@ -102,6 +102,14 @@ class AgentConfig(BaseModel): default_factory=list, description="Tags for categorizing the agent", ) + auto_generate_prompt: bool = Field( + default_factory=bool, + description="Auto generate a prompt based on the input", + ) + max_tokens: int = Field( + default_factory=int, + description="The number of max output tokens", + ) class AgentUpdate(BaseModel): @@ -197,9 +205,9 @@ class AgentStore: user_name=config.user_name, retry_attempts=config.retry_attempts, context_length=config.context_length, - return_step_meta=True, output_type="str", - streaming_on=config.streaming_on, + auto_generate_prompt=config.auto_generate_prompt, + max_tokens=config.max_tokens, ) agent_id = uuid4() @@ -441,6 +449,8 @@ class AgentStore: "agent_name": agent.agent_name, "model_name": agent.llm.model_name, "temperature": agent.llm.temperature, + "max_loops": agent.max_loops, + "context_window": agent.context_length, }, timestamp=datetime.utcnow(), processing_time=processing_time, diff --git a/api/test_api.py b/api/test_api.py new file mode 100644 index 00000000..1153d946 --- /dev/null +++ b/api/test_api.py @@ -0,0 +1,112 @@ +import requests +import json +from time import sleep + +BASE_URL = "http://api.swarms.ai:8000" + + +def make_request(method, endpoint, data=None): + """Helper function to make requests with error handling""" + url = f"{BASE_URL}{endpoint}" + try: + if method == "GET": + response = requests.get(url) + elif method == "POST": + response = requests.post(url, json=data) + elif method == "DELETE": + response = requests.delete(url) + + response.raise_for_status() 
+ return response.json() + except requests.exceptions.RequestException as e: + print( + f"Error making {method} request to {endpoint}: {str(e)}" + ) + if hasattr(e.response, "text"): + print(f"Response text: {e.response.text}") + return None + + +def create_agent(): + """Create a test agent""" + data = { + "agent_name": "test_agent", + "model_name": "gpt-4", + "system_prompt": "You are a helpful assistant", + "description": "Test agent", + "temperature": 0.7, + "max_loops": 1, + "tags": ["test"], + } + return make_request("POST", "/v1/agent", data) + + +def list_agents(): + """List all agents""" + return make_request("GET", "/v1/agents") + + +def test_completion(agent_id): + """Test a completion with the agent""" + data = { + "prompt": "Say hello!", + "agent_id": agent_id, + "max_tokens": 100, + } + return make_request("POST", "/v1/agent/completions", data) + + +def get_agent_metrics(agent_id): + """Get metrics for an agent""" + return make_request("GET", f"/v1/agent/{agent_id}/metrics") + + +def delete_agent(agent_id): + """Delete an agent""" + return make_request("DELETE", f"/v1/agent/{agent_id}") + + +def run_tests(): + print("Starting API tests...") + + # Create an agent + print("\n1. Creating agent...") + agent_response = create_agent() + if not agent_response: + print("Failed to create agent") + return + + agent_id = agent_response.get("agent_id") + print(f"Created agent with ID: {agent_id}") + + # Give the server a moment to process + sleep(2) + + # List agents + print("\n2. Listing agents...") + agents = list_agents() + print(f"Found {len(agents)} agents") + + # Test completion + if agent_id: + print("\n3. Testing completion...") + completion = test_completion(agent_id) + if completion: + print( + f"Completion response: {completion.get('response')}" + ) + + print("\n4. Getting agent metrics...") + metrics = get_agent_metrics(agent_id) + if metrics: + print(f"Agent metrics: {json.dumps(metrics, indent=2)}") + + # Clean up + # print("\n5. 
Cleaning up - deleting agent...") + # delete_result = delete_agent(agent_id) + # if delete_result: + # print("Successfully deleted agent") + + +if __name__ == "__main__": + run_tests() diff --git a/byte.py b/byte.py deleted file mode 100644 index d0a5a92f..00000000 --- a/byte.py +++ /dev/null @@ -1,898 +0,0 @@ -from enum import Enum -from typing import Union, Optional -import io -from PIL import Image -import numpy as np -import torch -import struct - - -from enum import auto -from typing import List, Dict, Tuple -import wave -from dataclasses import dataclass -import torch.nn as nn -import torch.nn.functional as F -from loguru import logger -from einops import rearrange -from torch import Tensor - - -@dataclass -class ModelConfig: - """Configuration for the enhanced BytePredictor model.""" - - vocab_size: int = 256 # Standard byte range - hidden_size: int = 1024 - num_layers: int = 12 - num_key_value_heads: int = 8 # For multi-query attention - num_query_heads: int = 32 # More query heads than kv heads - dropout: float = 0.1 - max_sequence_length: int = 8192 - rope_theta: float = 10000.0 - layer_norm_eps: float = 1e-5 - vocab_parallel: bool = False - qk_norm: bool = True - qk_norm_scale: float = None - attention_bias: bool = False - - -class MultiQueryAttention(nn.Module): - """Fixed Multi-Query Attention implementation.""" - - def __init__(self, config: ModelConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.num_query_heads = config.num_query_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_query_heads - self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5) - - self.q_proj = nn.Linear( - config.hidden_size, config.num_query_heads * self.head_dim - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - ) - self.o_proj = 
nn.Linear( - config.num_query_heads * self.head_dim, config.hidden_size - ) - - self.qk_norm = config.qk_norm - if self.qk_norm: - self.q_norm = nn.LayerNorm(self.head_dim) - self.k_norm = nn.LayerNorm(self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - batch_size, seq_length, _ = hidden_states.shape - - # Project and reshape - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Reshape to [seq_len, batch, heads, head_dim] - q = q.view( - batch_size, - seq_length, - self.num_query_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - k = k.view( - batch_size, - seq_length, - self.num_key_value_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - v = v.view( - batch_size, - seq_length, - self.num_key_value_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - - # Apply rotary embeddings - # q, k = self.rotary(q, k, seq_length) - - # Apply QK normalization if enabled - if self.qk_norm: - q = self.q_norm(q) - k = self.k_norm(k) - - # Handle MQA head expansion - if self.num_key_value_heads != self.num_query_heads: - k = k.repeat_interleave( - self.num_query_heads // self.num_key_value_heads, - dim=2, - ) - v = v.repeat_interleave( - self.num_query_heads // self.num_key_value_heads, - dim=2, - ) - - # Compute attention - # Reshape for matmul: [batch, heads, seq_length, head_dim] - q = q.permute(1, 2, 0, 3) - k = k.permute(1, 2, 0, 3) - v = v.permute(1, 2, 0, 3) - - attn_weights = ( - torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = F.softmax(attn_weights, dim=-1) - - output = torch.matmul(attn_weights, v) - - # Reshape back to [batch, seq_length, hidden_size] - output = ( - output.transpose(1, 2) - .contiguous() - .view(batch_size, seq_length, -1) - ) - output = self.o_proj(output) - - return output - - -class 
EnhancedBytePredictor(nn.Module): - """Enhanced byte prediction model with state-of-the-art features.""" - - def __init__(self, config: ModelConfig): - super().__init__() - self.config = config - - # Token embeddings - self.tok_embeddings = nn.Embedding( - config.vocab_size, config.hidden_size - ) - - # Transformer layers - self.layers = nn.ModuleList( - [ - nn.ModuleDict( - { - "attention": MultiQueryAttention(config), - "attention_norm": nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ), - "feed_forward": nn.Sequential( - nn.Linear( - config.hidden_size, - 4 * config.hidden_size, - ), - nn.GELU(), - nn.Linear( - 4 * config.hidden_size, - config.hidden_size, - ), - ), - "feed_forward_norm": nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ), - } - ) - for _ in range(config.num_layers) - ] - ) - - self.norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.output = nn.Linear( - config.hidden_size, config.vocab_size, bias=False - ) - - # Initialize weights - self.apply(self._init_weights) - - def _init_weights(self, module: nn.Module) -> None: - """Initialize weights with scaled normal distribution.""" - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass of the model. 
- - Args: - input_ids: Tensor of shape (batch_size, sequence_length) - attention_mask: Optional attention mask - - Returns: - Tensor of logits with shape (batch_size, sequence_length, vocab_size) - """ - hidden_states = self.tok_embeddings(input_ids) - - # Create causal mask if needed - if attention_mask is None: - attention_mask = torch.triu( - torch.ones( - (input_ids.size(1), input_ids.size(1)), - device=input_ids.device, - dtype=torch.bool, - ), - diagonal=1, - ) - attention_mask = attention_mask.masked_fill( - attention_mask == 1, float("-inf") - ) - - # Apply transformer layers - for layer in self.layers: - # Attention block - hidden_states = hidden_states + layer["attention"]( - layer["attention_norm"](hidden_states), attention_mask - ) - - # Feed-forward block - hidden_states = hidden_states + layer["feed_forward"]( - layer["feed_forward_norm"](hidden_states) - ) - - hidden_states = self.norm(hidden_states) - logits = self.output(hidden_states) - - return logits - - def compute_loss( - self, - input_ids: torch.Tensor, - target_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Compute cross entropy loss. - - Args: - input_ids: Input token ids - target_ids: Target token ids - attention_mask: Optional attention mask - - Returns: - Loss value - """ - logits = self(input_ids, attention_mask) - loss = F.cross_entropy( - rearrange(logits, "b s v -> (b s) v"), - rearrange(target_ids, "b s -> (b s)"), - ) - return loss - - @torch.no_grad() - def _generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int = 100, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - repetition_penalty: float = 1.0, - ) -> torch.Tensor: - """ - Generate new tokens autoregressively. 
- - Args: - input_ids: Starting sequence - max_new_tokens: Number of tokens to generate - temperature: Sampling temperature - top_k: K for top-k sampling - top_p: P for nucleus sampling - repetition_penalty: Penalty for repeating tokens - - Returns: - Generated sequence - """ - batch_size, seq_length = input_ids.shape - generated = input_ids.clone() - - for _ in range(max_new_tokens): - if generated.size(1) >= self.config.max_sequence_length: - break - - # Forward pass - logits = self(generated)[:, -1, :] - - # Apply temperature - logits = logits / temperature - - # Apply repetition penalty - if repetition_penalty != 1.0: - for i in range(batch_size): - for token_id in set(generated[i].tolist()): - logits[i, token_id] /= repetition_penalty - - # Apply top-k sampling - if top_k is not None: - indices_to_remove = ( - logits - < torch.topk(logits, top_k)[0][..., -1, None] - ) - logits[indices_to_remove] = float("-inf") - - # Apply nucleus (top-p) sampling - if top_p is not None: - sorted_logits, sorted_indices = torch.sort( - logits, descending=True - ) - cumulative_probs = torch.cumsum( - F.softmax(sorted_logits, dim=-1), dim=-1 - ) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = ( - sorted_indices_to_remove[..., :-1].clone() - ) - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = torch.zeros_like( - logits, dtype=torch.bool - ) - indices_to_remove.scatter_( - 1, sorted_indices, sorted_indices_to_remove - ) - logits[indices_to_remove] = float("-inf") - - # Sample next token - probs = F.softmax(logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - - # Append to sequence - generated = torch.cat([generated, next_token], dim=1) - - return generated - - def generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int = 100, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - 
repetition_penalty: float = 1.0, - ): - tensor_data = self._generate( - input_ids=input_ids, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - ) - - return tensor_to_data(tensor_data) - - -# import torch -# from typing import Optional - - -class DataType(Enum): - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - BINARY = "binary" - - -class ByteDetokenizer: - """Utility class for converting model output bytes back to original data formats.""" - - @staticmethod - def tensor_to_bytes(tensor: torch.Tensor) -> bytes: - """Convert model output tensor to bytes.""" - # Convert logits/probabilities to byte values - if tensor.dim() > 1: - # If we have logits, convert to byte indices - byte_indices = tensor.argmax(dim=-1) - else: - byte_indices = tensor - - # Convert to Python bytes - return bytes( - byte_indices.cpu().numpy().astype(np.uint8).tolist() - ) - - @staticmethod - def decode_text(byte_sequence: bytes) -> str: - """Convert bytes to text.""" - try: - return byte_sequence.decode("utf-8") - except UnicodeDecodeError: - # Try with error handling - return byte_sequence.decode("utf-8", errors="replace") - - @staticmethod - def decode_image( - byte_sequence: bytes, - mode: str = "RGB", - size: Optional[tuple] = None, - ) -> Image.Image: - """Convert bytes to image. - - Args: - byte_sequence: Raw image bytes - mode: Image mode (RGB, RGBA, L, etc.) 
- size: Optional tuple of (width, height) - """ - try: - # Try to load as-is first (for standard image formats) - img = Image.open(io.BytesIO(byte_sequence)) - if size: - img = img.resize(size) - return img - except: - # If failed, assume raw pixel data - if not size: - # Try to determine size from byte count - pixel_count = len(byte_sequence) // len(mode) - size = ( - int(np.sqrt(pixel_count)), - int(np.sqrt(pixel_count)), - ) - - # Convert raw bytes to pixel array - pixels = np.frombuffer(byte_sequence, dtype=np.uint8) - pixels = pixels.reshape((*size, len(mode))) - - return Image.fromarray(pixels, mode=mode) - - @staticmethod - def decode_audio( - byte_sequence: bytes, - sample_rate: int = 44100, - channels: int = 2, - sample_width: int = 2, - ) -> np.ndarray: - """Convert bytes to audio samples. - - Args: - byte_sequence: Raw audio bytes - sample_rate: Audio sample rate in Hz - channels: Number of audio channels - sample_width: Bytes per sample (1, 2, or 4) - """ - # Determine format string based on sample width - format_str = { - 1: "b", # signed char - 2: "h", # short - 4: "i", # int - }[sample_width] - - # Unpack bytes to samples - sample_count = len(byte_sequence) // (channels * sample_width) - samples = struct.unpack( - f"<{sample_count * channels}{format_str}", byte_sequence - ) - - # Reshape to [samples, channels] - return np.array(samples).reshape(-1, channels) - - def decode_data( - self, - model_output: Union[torch.Tensor, bytes], - data_type: DataType, - **kwargs, - ) -> Union[str, Image.Image, np.ndarray, bytes]: - """Main method to decode model output to desired format. 
- - Args: - model_output: Either tensor from model or raw bytes - data_type: Type of data to decode to - **kwargs: Additional parameters for specific decoders - - Returns: - Decoded data in specified format - """ - # Convert tensor to bytes if needed - if isinstance(model_output, torch.Tensor): - byte_sequence = self.tensor_to_bytes(model_output) - else: - byte_sequence = model_output - - # Decode based on type - if data_type == DataType.TEXT: - return self.decode_text(byte_sequence) - elif data_type == DataType.IMAGE: - return self.decode_image(byte_sequence, **kwargs) - elif data_type == DataType.AUDIO: - return self.decode_audio(byte_sequence, **kwargs) - elif data_type == DataType.VIDEO: - raise NotImplementedError( - "Video decoding not yet implemented" - ) - else: # BINARY - return byte_sequence - - -# Usage example - - -class Modality(Enum): - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - BINARY = auto() - MULTIMODAL = auto() - - -@dataclass -class ModalityInfo: - """Information about detected modality.""" - - modality: Modality - confidence: float - metadata: Dict[str, any] - sub_modalities: Optional[List["ModalityInfo"]] = None - - -class ModalityDetector: - """Detects data modalities from byte sequences.""" - - # Common file signatures (magic numbers) - SIGNATURES = { - # Images - b"\xFF\xD8\xFF": "JPEG", - b"\x89PNG\r\n\x1a\n": "PNG", - b"GIF87a": "GIF", - b"GIF89a": "GIF", - b"RIFF": "WEBP", - # Audio - b"RIFF....WAVE": "WAV", - b"ID3": "MP3", - b"\xFF\xFB": "MP3", - b"OggS": "OGG", - # Video - b"\x00\x00\x00\x18ftypmp42": "MP4", - b"\x00\x00\x00\x1Cftypav01": "MP4", - b"\x1A\x45\xDF\xA3": "WEBM", - } - - def __init__(self): - self.magic = magic.Magic(mime=True) - - def _check_text_probability(self, data: bytes) -> float: - """Estimate probability that data is text.""" - # Check if data is valid UTF-8 - try: - data.decode("utf-8") - # Count printable ASCII characters - printable = sum(1 for b in data if 32 <= b <= 126) - return 
printable / len(data) - except UnicodeDecodeError: - return 0.0 - - def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]: - """Check if data is a valid image and extract metadata.""" - try: - with io.BytesIO(data) as bio: - img = Image.open(bio) - return True, { - "format": img.format, - "size": img.size, - "mode": img.mode, - } - except: - return False, {} - - def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]: - """Check if data is valid audio and extract metadata.""" - try: - with io.BytesIO(data) as bio: - # Try to parse as WAV - with wave.open(bio) as wav: - return True, { - "channels": wav.getnchannels(), - "sample_width": wav.getsampwidth(), - "framerate": wav.getframerate(), - "frames": wav.getnframes(), - } - except: - # Check for other audio signatures - for sig in [b"ID3", b"\xFF\xFB", b"OggS"]: - if data.startswith(sig): - return True, {"format": "compressed_audio"} - return False, {} - - def _detect_boundaries( - self, data: bytes - ) -> List[Tuple[int, int, Modality]]: - """Detect boundaries between different modalities in the data.""" - boundaries = [] - current_pos = 0 - - while current_pos < len(data): - # Look for known signatures - for sig, format_type in self.SIGNATURES.items(): - if data[current_pos:].startswith(sig): - # Found a signature, determine its length - if format_type in ["JPEG", "PNG", "GIF"]: - # Find image end - try: - with io.BytesIO( - data[current_pos:] - ) as bio: - img = Image.open(bio) - img.verify() - size = bio.tell() - boundaries.append( - ( - current_pos, - current_pos + size, - Modality.IMAGE, - ) - ) - current_pos += size - continue - except: - pass - - # Check for text sections - text_prob = self._check_text_probability( - data[current_pos : current_pos + 1024] - ) - if text_prob > 0.8: - # Look for end of text section - end_pos = current_pos + 1 - while end_pos < len(data): - if ( - self._check_text_probability( - data[end_pos : end_pos + 32] - ) - < 0.5 - ): - break - end_pos += 1 - 
boundaries.append( - (current_pos, end_pos, Modality.TEXT) - ) - current_pos = end_pos - continue - - current_pos += 1 - - return boundaries - - def detect_modality(self, data: bytes) -> ModalityInfo: - """Detect modality of byte sequence.""" - # First check for single modality - mime_type = self.magic.from_buffer(data) - - # Check text - text_prob = self._check_text_probability(data) - if text_prob > 0.9: - return ModalityInfo( - modality=Modality.TEXT, - confidence=text_prob, - metadata={"mime_type": mime_type}, - ) - - # Check image - is_image, image_meta = self._check_image_validity(data) - if is_image: - return ModalityInfo( - modality=Modality.IMAGE, - confidence=1.0, - metadata={**image_meta, "mime_type": mime_type}, - ) - - # Check audio - is_audio, audio_meta = self._check_audio_validity(data) - if is_audio: - return ModalityInfo( - modality=Modality.AUDIO, - confidence=1.0, - metadata={**audio_meta, "mime_type": mime_type}, - ) - - # Check for multimodal content - boundaries = self._detect_boundaries(data) - if len(boundaries) > 1: - sub_modalities = [] - for start, end, modality in boundaries: - chunk_data = data[start:end] - sub_info = self.detect_modality(chunk_data) - if sub_info.modality != Modality.BINARY: - sub_modalities.append(sub_info) - - if sub_modalities: - return ModalityInfo( - modality=Modality.MULTIMODAL, - confidence=0.8, - metadata={"mime_type": "multipart/mixed"}, - sub_modalities=sub_modalities, - ) - - # Default to binary - return ModalityInfo( - modality=Modality.BINARY, - confidence=0.5, - metadata={"mime_type": mime_type}, - ) - - def split_modalities( - self, data: bytes - ) -> List[Tuple[Modality, bytes, Dict]]: - """Split multimodal data into separate modalities.""" - boundaries = self._detect_boundaries(data) - result = [] - - for start, end, modality in boundaries: - chunk = data[start:end] - info = self.detect_modality(chunk) - result.append((modality, chunk, info.metadata)) - - return result - - -class 
AutoDetectBytesDecoder: - """Decoder that automatically detects and decodes different modalities.""" - - def __init__(self): - self.detector = ModalityDetector() - self.text_decoder = ByteDetokenizer() # From previous example - - def decode( - self, data: bytes - ) -> Union[str, Image.Image, np.ndarray, List[any]]: - """Automatically detect and decode byte sequence.""" - info = self.detector.detect_modality(data) - - if info.modality == Modality.MULTIMODAL: - # Handle multimodal content - parts = self.detector.split_modalities(data) - return [ - self.decode(chunk) for modality, chunk, _ in parts - ] - - if info.modality == Modality.TEXT: - return self.text_decoder.decode_text(data) - elif info.modality == Modality.IMAGE: - return self.text_decoder.decode_image(data) - elif info.modality == Modality.AUDIO: - return self.text_decoder.decode_audio(data) - else: - return data - - -# # Example usage -# def demo_auto_detection(): -# """Demonstrate auto modality detection.""" -# # Create mixed content -# text = "Hello, World!".encode('utf-8') - -# # Create a small test image -# img = Image.new('RGB', (100, 100), color='red') -# img_bytes = io.BytesIO() -# img.save(img_bytes, format='PNG') - -# # Combine into multimodal content -# mixed_content = text + img_bytes.getvalue() - -# # Initialize decoder -# decoder = AutoDetectBytesDecoder() - -# # Decode -# result = decoder.decode(mixed_content) - -# if isinstance(result, list): -# print("Detected multimodal content:") -# for i, part in enumerate(result): -# print(f"Part {i+1}: {type(part)}") - -# if __name__ == "__main__": -# demo_auto_detection() - - -def tensor_to_data(tensor: Tensor): - byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor) - - # Initialize auto-detector - decoder = AutoDetectBytesDecoder() - - # Decode with automatic detection - result = decoder.decode(byte_sequence) - - return result - - -def demo_byte_predictor(): - """Demo with smaller dimensions to test.""" - # Initialize model configuration with 
adjusted dimensions - config = ModelConfig( - vocab_size=256, - hidden_size=128, # Smaller for testing - num_layers=2, # Fewer layers for testing - num_key_value_heads=2, - num_query_heads=4, - dropout=0.1, - max_sequence_length=1024, - ) - - # Initialize model - model = EnhancedBytePredictor(config) - logger.info("Model initialized") - - # Move to GPU if available - device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - model = model.to(device) - logger.info(f"Using device: {device}") - - # Create sample input data - batch_size = 2 - seq_length = 16 # Shorter sequence for testing - input_ids = torch.randint( - 0, config.vocab_size, (batch_size, seq_length), device=device - ) - logger.info(f"Created input tensor of shape: {input_ids.shape}") - - # Test forward pass - try: - logits = model(input_ids) - logger.info( - f"Forward pass successful! Output shape: {logits.shape}" - ) - - # Test loss computation - target_ids = torch.randint( - 0, - config.vocab_size, - (batch_size, seq_length), - device=device, - ) - loss = model.compute_loss(input_ids, target_ids) - logger.info( - f"Loss computation successful! Loss value: {loss.item():.4f}" - ) - - # Test generation - prompt = torch.randint( - 0, - config.vocab_size, - (1, 4), # Very short prompt for testing - device=device, - ) - generated = model.generate( - prompt, max_new_tokens=8, temperature=0.8, top_k=50 - ) - logger.info( - f"Generation successful! 
Generated shape: {generated.shape}" - ) - - except Exception as e: - logger.error(f"Error during execution: {str(e)}") - raise - - -if __name__ == "__main__": - # Set up logging - # logger.remove() # Remove default handler - # logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message}") - - demo_byte_predictor() diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 53b4d273..fc3e0a4c 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -222,11 +222,11 @@ nav: - BaseMultiModalModel: "swarms/models/base_multimodal_model.md" - Multi Modal Models Available: "swarms/models/multimodal_models.md" - GPT4VisionAPI: "swarms/models/gpt4v.md" - - Swarms Cloud API: - # - Overview: "swarms_cloud/main.md" - - Overview: "swarms_cloud/vision.md" - - Swarms Cloud CLI: "swarms_cloud/cli.md" - - Add Agents to Marketplace: "swarms_cloud/add_agent.md" + # - Swarms Cloud API: + # # - Overview: "swarms_cloud/main.md" + # - Overview: "swarms_cloud/vision.md" + # - Swarms Cloud CLI: "swarms_cloud/cli.md" + # # - Add Agents to Marketplace: "swarms_cloud/add_agent.md" # - Available Models: "swarms_cloud/available_models.md" # - Agent API: "swarms_cloud/agent_api.md" # - Migrate from OpenAI to Swarms in 3 lines of code: "swarms_cloud/migrate_openai.md" diff --git a/example.py b/example.py index 7647d1cd..4f2d2f3f 100644 --- a/example.py +++ b/example.py @@ -1,28 +1,13 @@ -import os - -from dotenv import load_dotenv -from swarm_models import OpenAIChat - from swarms import Agent from swarms.prompts.finance_agent_sys_prompt import ( FINANCIAL_AGENT_SYS_PROMPT, ) -load_dotenv() - -# Get the OpenAI API key from the environment variable -api_key = os.getenv("OPENAI_API_KEY") - -# Create an instance of the OpenAIChat class -model = OpenAIChat( - openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 -) - # Initialize the agent agent = Agent( agent_name="Financial-Analysis-Agent", system_prompt=FINANCIAL_AGENT_SYS_PROMPT, - llm=model, + model_name="gpt-4o-mini", max_loops=1, 
autosave=True, dashboard=False, @@ -33,14 +18,10 @@ agent = Agent( retry_attempts=1, streaming_on=True, context_length=200000, - return_step_meta=True, - output_type="json", # "json", "dict", "csv" OR "string" soon "yaml" and + return_step_meta=False, + output_type="str", # "json", "dict", "csv" OR "string" soon "yaml" and auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task - artifacts_on=True, - artifacts_output_path="roth_ira_report", - artifacts_file_extension=".txt", max_tokens=8000, - return_history=True, ) diff --git a/new_features_examples/auto_agent.py b/new_features_examples/auto_agent.py index 712be089..7c7ee1d1 100644 --- a/new_features_examples/auto_agent.py +++ b/new_features_examples/auto_agent.py @@ -12,21 +12,31 @@ class DynamicParser: @staticmethod def extract_fields(model: Type[BaseModel]) -> Dict[str, Any]: return { - field_name: (field.annotation, ... if field.is_required() else None) + field_name: ( + field.annotation, + ... if field.is_required() else None, + ) for field_name, field in model.model_fields.items() } - + @staticmethod - def create_partial_model(model: Type[BaseModel], data: Dict[str, Any]) -> Type[BaseModel]: + def create_partial_model( + model: Type[BaseModel], data: Dict[str, Any] + ) -> Type[BaseModel]: fields = { - field_name: (field.annotation, ... if field.is_required() else None) + field_name: ( + field.annotation, + ... 
if field.is_required() else None, + ) for field_name, field in model.model_fields.items() if field_name in data } return create_model(f"Partial{model.__name__}", **fields) @classmethod - def parse(cls, data: Union[str, Dict[str, Any]], model: Type[BaseModel]) -> Optional[BaseModel]: + def parse( + cls, data: Union[str, Dict[str, Any]], model: Type[BaseModel] + ) -> Optional[BaseModel]: if isinstance(data, str): try: data = json.loads(data) @@ -47,25 +57,52 @@ class DynamicParser: load_dotenv() + # Define the Thoughts schema class Thoughts(BaseModel): - text: str = Field(..., description="Current thoughts or observations regarding the task.") - reasoning: str = Field(..., description="Logical reasoning behind the thought process.") - plan: str = Field(..., description="A short bulleted list that conveys the immediate and long-term plan.") - criticism: str = Field(..., description="Constructive self-criticism to improve future responses.") - speak: str = Field(..., description="A concise summary of thoughts intended for the user.") + text: str = Field( + ..., + description="Current thoughts or observations regarding the task.", + ) + reasoning: str = Field( + ..., + description="Logical reasoning behind the thought process.", + ) + plan: str = Field( + ..., + description="A short bulleted list that conveys the immediate and long-term plan.", + ) + criticism: str = Field( + ..., + description="Constructive self-criticism to improve future responses.", + ) + speak: str = Field( + ..., + description="A concise summary of thoughts intended for the user.", + ) + # Define the Command schema class Command(BaseModel): - name: str = Field(..., description="Command name to execute from the provided list of commands.") - args: Dict[str, Any] = Field(..., description="Arguments required to execute the command.") + name: str = Field( + ..., + description="Command name to execute from the provided list of commands.", + ) + args: Dict[str, Any] = Field( + ..., 
description="Arguments required to execute the command." + ) + # Define the AgentResponse schema class AgentResponse(BaseModel): - thoughts: Thoughts = Field(..., description="The agent's current thoughts and reasoning.") - command: Command = Field(..., description="The command to execute along with its arguments.") - - + thoughts: Thoughts = Field( + ..., description="The agent's current thoughts and reasoning." + ) + command: Command = Field( + ..., + description="The command to execute along with its arguments.", + ) + # Define tool functions def fluid_api_command(task: str): @@ -90,17 +127,26 @@ def do_nothing_command(): def task_complete_command(reason: str): """Mark the task as complete and provide a reason.""" print(f"Task completed: {reason}") - return {"status": "success", "message": f"Task completed: {reason}"} + return { + "status": "success", + "message": f"Task completed: {reason}", + } # Dynamic command execution def execute_command(name: str, args: Dict[str, Any]): """Dynamically execute a command based on its name and arguments.""" command_map: Dict[str, Callable] = { - "fluid_api": lambda **kwargs: fluid_api_command(task=kwargs.get("task")), - "send_tweet": lambda **kwargs: send_tweet_command(text=kwargs.get("text")), + "fluid_api": lambda **kwargs: fluid_api_command( + task=kwargs.get("task") + ), + "send_tweet": lambda **kwargs: send_tweet_command( + text=kwargs.get("text") + ), "do_nothing": lambda **kwargs: do_nothing_command(), - "task_complete": lambda **kwargs: task_complete_command(reason=kwargs.get("reason")), + "task_complete": lambda **kwargs: task_complete_command( + reason=kwargs.get("reason") + ), } if name not in command_map: @@ -110,23 +156,26 @@ def execute_command(name: str, args: Dict[str, Any]): return command_map[name](**args) -def parse_and_execute_command(response: Union[str, Dict[str, Any]], base_model: Type[BaseModel] = AgentResponse) -> Any: +def parse_and_execute_command( + response: Union[str, Dict[str, Any]], + 
base_model: Type[BaseModel] = AgentResponse, +) -> Any: """Enhanced command parser with flexible input handling""" parsed = DynamicParser.parse(response, base_model) if not parsed: raise ValueError("Failed to parse response") - - if hasattr(parsed, 'command'): + + if hasattr(parsed, "command"): command_name = parsed.command.name command_args = parsed.command.args return execute_command(command_name, command_args) - + return parsed ainame = "AutoAgent" userprovided = "assistant" - + SYSTEM_PROMPT = f""" You are {ainame}, an advanced and autonomous {userprovided}. Your role is to make decisions and complete tasks independently without seeking user assistance. Leverage your strengths as an LLM to solve tasks efficiently, adhering strictly to the commands and resources provided. @@ -174,7 +223,7 @@ model = OpenAIFunctionCaller( temperature=0.9, base_model=AgentResponse, # Pass the Pydantic schema as the base model parallel_tool_calls=False, - openai_api_key=os.getenv("OPENAI_API_KEY") + openai_api_key=os.getenv("OPENAI_API_KEY"), ) # Example usage diff --git a/new_features_examples/markdown_agent.py b/new_features_examples/markdown_agent.py new file mode 100644 index 00000000..51e15a97 --- /dev/null +++ b/new_features_examples/markdown_agent.py @@ -0,0 +1,8 @@ +from swarms import Agent + +Agent( + agent_name="Stock-Analysis-Agent", + model_name="gpt-4o-mini", + max_loops=1, + streaming_on=True, +).run("What are 5 hft algorithms") diff --git a/pyproject.toml b/pyproject.toml index 5102f0d2..6df29882 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,8 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "6.4.7" -description = "Swarms - Pytorch" +version = "6.5.7" +description = "Swarms - TGSC" license = "MIT" authors = ["Kye Gomez "] homepage = "https://github.com/kyegomez/swarms" @@ -57,13 +57,13 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.10,<4.0" -torch = ">=2.1.1,<3.0" -transformers = ">= 4.39.0, <5.0.0" +# 
torch = ">=2.1.1,<3.0" +# transformers = ">= 4.39.0, <5.0.0" asyncio = ">=3.4.3,<4.0" toml = "*" pypdf = "4.3.1" loguru = "*" -pydantic = "2.8.2" +pydantic = ">=2.8.2<3.0" tenacity = "*" psutil = "*" sentry-sdk = {version = "*", extras = ["http"]} # Updated here @@ -73,12 +73,37 @@ docstring_parser = "0.16" tiktoken = "*" networkx = "*" aiofiles = "*" -swarm-models = "*" clusterops = "*" -chromadb = "*" +# chromadb = "*" reportlab = "*" doc-master = "*" rich = "*" +# sentence-transformers = "*" +swarm-models = "*" + + +# [tool.poetry.extras] +# # Extra for NLP-related functionalities +# nlp = [ +# "torch>=2.1.1,<3.0", +# "transformers>=4.39.0,<5.0.0", +# "sentence-transformers", +# "swarm-models", +# ] + +# # Extra for database-related functionalities +# db = ["chromadb"] + +# # All optional dependencies for convenience +# all = [ +# "torch>=2.1.1,<3.0", +# "transformers>=4.39.0,<5.0.0", +# "sentence-transformers", +# "chromadb", +# "swarm-models" +# ] + + [tool.poetry.scripts] swarms = "swarms.cli.main:main" diff --git a/real_time.py b/real_time.py new file mode 100644 index 00000000..fe55878d --- /dev/null +++ b/real_time.py @@ -0,0 +1,618 @@ +import torch +from torch.utils.data import DataLoader, TensorDataset +import numpy as np +from loguru import logger + +from dataclasses import dataclass +from typing import Optional, Tuple, Dict +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +@dataclass +class TransformerConfig: + """Configuration class for MoE Transformer model parameters.""" + + vocab_size: int = 50257 + hidden_size: int = 768 + num_attention_heads: int = 12 + num_expert_layers: int = 4 + num_experts: int = 8 + expert_capacity: int = 32 + max_position_embeddings: int = 1024 + dropout_prob: float = 0.1 + layer_norm_epsilon: float = 1e-5 + initializer_range: float = 0.02 + num_query_groups: int = 4 # For multi-query attention + + +class ExpertLayer(nn.Module): + """Individual expert neural network.""" + + 
def __init__(self, config: TransformerConfig): + super().__init__() + self.fc1 = nn.Linear( + config.hidden_size, 4 * config.hidden_size + ) + self.fc2 = nn.Linear( + 4 * config.hidden_size, config.hidden_size + ) + self.activation = nn.GELU() + self.dropout = nn.Dropout(config.dropout_prob) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.fc2(x) + return x + + +class MixtureOfExperts(nn.Module): + """Mixture of Experts layer with dynamic routing.""" + + def __init__(self, config: TransformerConfig): + super().__init__() + self.num_experts = config.num_experts + self.expert_capacity = config.expert_capacity + + # Create expert networks + self.experts = nn.ModuleList( + [ExpertLayer(config) for _ in range(config.num_experts)] + ) + + # Router network + self.router = nn.Linear( + config.hidden_size, config.num_experts + ) + + def forward(self, x: Tensor) -> Tuple[Tensor, Dict]: + """Route inputs to experts and combine outputs.""" + batch_size, seq_len, hidden_size = x.shape + + # Calculate routing probabilities + router_logits = self.router(x) + routing_weights = F.softmax(router_logits, dim=-1) + + # Select top-k experts + top_k = 2 + gates, indices = torch.topk(routing_weights, top_k, dim=-1) + gates = F.softmax(gates, dim=-1) + + # Process inputs through selected experts + final_output = torch.zeros_like(x) + router_load = torch.zeros(self.num_experts, device=x.device) + + for i in range(top_k): + expert_index = indices[..., i] + gate = gates[..., i : i + 1] + + # Count expert assignments + for j in range(self.num_experts): + router_load[j] += (expert_index == j).float().sum() + + # Process through selected experts + for j in range(self.num_experts): + mask = expert_index == j + if not mask.any(): + continue + + expert_input = x[mask] + expert_output = self.experts[j](expert_input) + final_output[mask] += gate[mask] * expert_output + + aux_loss = router_load.float().var() / ( + 
router_load.float().mean() ** 2 + ) + + return final_output, {"load_balancing_loss": aux_loss} + + +class MultiQueryAttention(nn.Module): + """Multi-Query Attention mechanism with proper multi-query group handling.""" + + def __init__(self, config: TransformerConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.num_query_groups = config.num_query_groups + self.hidden_size = config.hidden_size + self.head_dim = ( + config.hidden_size // config.num_attention_heads + ) + + # Query projection maintains full head dimension + self.q_proj = nn.Linear( + config.hidden_size, config.hidden_size + ) + + # Key and value projections use reduced number of heads (query groups) + self.k_proj = nn.Linear( + config.hidden_size, + self.head_dim * config.num_query_groups, + ) + self.v_proj = nn.Linear( + config.hidden_size, + self.head_dim * config.num_query_groups, + ) + + self.dropout = nn.Dropout(config.dropout_prob) + + # Calculate heads per group for proper reshaping + self.heads_per_group = ( + self.num_attention_heads // self.num_query_groups + ) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + cache: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + # Project queries, keys, and values + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + # Reshape queries to full number of heads + queries = queries.view( + batch_size, + seq_length, + self.num_attention_heads, + self.head_dim, + ) + + # Reshape keys and values to number of query groups + keys = keys.view( + batch_size, + seq_length, + self.num_query_groups, + self.head_dim, + ) + values = values.view( + batch_size, + seq_length, + self.num_query_groups, + self.head_dim, + ) + + # Transpose for batch matrix multiplication + queries = queries.transpose( + 1, 2 + ) # (batch, n_heads, seq_len, 
head_dim) + keys = keys.transpose( + 1, 2 + ) # (batch, n_groups, seq_len, head_dim) + values = values.transpose( + 1, 2 + ) # (batch, n_groups, seq_len, head_dim) + + # Repeat keys and values for each head in the group + keys = keys.repeat_interleave(self.heads_per_group, dim=1) + values = values.repeat_interleave(self.heads_per_group, dim=1) + + # Compute attention scores + scale = 1.0 / math.sqrt(self.head_dim) + scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale + + if attention_mask is not None: + # Expand attention mask to match scores dimensions + expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2) + expanded_mask = expanded_mask.expand( + batch_size, + self.num_attention_heads, + seq_length, + seq_length, + ) + mask_value = torch.finfo(scores.dtype).min + attention_mask = expanded_mask.eq(0).float() * mask_value + scores = scores + attention_mask + + attention_weights = F.softmax(scores, dim=-1) + attention_weights = self.dropout(attention_weights) + + # Compute attention output + attention_output = torch.matmul(attention_weights, values) + attention_output = attention_output.transpose(1, 2) + attention_output = attention_output.reshape( + batch_size, seq_length, -1 + ) + + return attention_output, None + + +class MoETransformer(nn.Module): + """ + Production-grade Transformer model with Mixture of Experts and Multi-Query Attention. 
+ + Features: + - Multi-Query Attention mechanism for efficient inference + - Mixture of Experts for dynamic routing and specialization + - Real-time weight updates based on input similarity + - Built-in logging and monitoring + - Type annotations for better code maintainability + """ + + def __init__(self, config: TransformerConfig): + super().__init__() + self.config = config + + # Initialize components + self.embedding = nn.Embedding( + config.vocab_size, config.hidden_size + ) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # Multi-Query Attention layers + self.attention_layers = nn.ModuleList( + [ + MultiQueryAttention(config) + for _ in range(config.num_expert_layers) + ] + ) + + # Mixture of Experts layers + self.moe_layers = nn.ModuleList( + [ + MixtureOfExperts(config) + for _ in range(config.num_expert_layers) + ] + ) + + # Layer normalization and dropout + self.layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_prob) + + # Output projection + self.output_projection = nn.Linear( + config.hidden_size, config.vocab_size + ) + + # Initialize weights + self.apply(self._init_weights) + logger.info("Initialized MoETransformer model") + + def _init_weights(self, module: nn.Module): + """Initialize model weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range + ) + if ( + isinstance(module, nn.Linear) + and module.bias is not None + ): + module.bias.data.zero_() + + def get_position_embeddings(self, position_ids: Tensor) -> Tensor: + """Generate position embeddings.""" + return self.position_embedding(position_ids) + + def forward( + self, + input_ids: Tensor, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + cache: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Dict]: + """ + Forward pass through the model. 
+ + Args: + input_ids: Input token IDs + attention_mask: Attention mask for padding + position_ids: Position IDs for positioning encoding + cache: Cache for key/value states in generation + + Returns: + tuple: (logits, auxiliary_outputs) + """ + batch_size, seq_length = input_ids.shape + + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device + ) + position_ids = position_ids.unsqueeze(0).expand_as( + input_ids + ) + + # Get embeddings + inputs_embeds = self.embedding(input_ids) + position_embeds = self.get_position_embeddings(position_ids) + hidden_states = inputs_embeds + position_embeds + + # Initialize auxiliary outputs + aux_outputs = {"moe_losses": []} + + # Process through transformer layers + for attention_layer, moe_layer in zip( + self.attention_layers, self.moe_layers + ): + # Multi-Query Attention + attention_output, _ = attention_layer( + hidden_states, attention_mask, cache + ) + hidden_states = self.layer_norm( + hidden_states + attention_output + ) + + # Mixture of Experts + moe_output, moe_aux = moe_layer(hidden_states) + hidden_states = self.layer_norm( + hidden_states + moe_output + ) + aux_outputs["moe_losses"].append( + moe_aux["load_balancing_loss"] + ) + + # Final output projection + logits = self.output_projection(hidden_states) + + return logits, aux_outputs + + def fetch_loss( + self, + logits: Tensor, + labels: Tensor, + aux_outputs: Dict, + reduction: str = "mean", + ) -> Tensor: + """ + Calculate the total loss including MoE balancing losses. 
+ + Args: + logits: Model output logits + labels: Ground truth labels + aux_outputs: Auxiliary outputs from forward pass + reduction: Loss reduction method + + Returns: + Tensor: Total loss + """ + # Calculate cross entropy loss + ce_loss = F.cross_entropy( + logits.view(-1, self.config.vocab_size), + labels.view(-1), + reduction=reduction, + ) + + # Calculate MoE loss + moe_loss = torch.stack(aux_outputs["moe_losses"]).mean() + + # Combine losses + total_loss = ce_loss + 0.01 * moe_loss + + logger.debug( + f"CE Loss: {ce_loss.item():.4f}, " + f"MoE Loss: {moe_loss.item():.4f}" + ) + + return total_loss + + @torch.no_grad() + def generate( + self, + input_ids: Tensor, + max_length: int = 100, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 0.9, + ) -> Tensor: + """ + Generate text using the model. + + Args: + input_ids: Initial input tokens + max_length: Maximum sequence length to generate + temperature: Sampling temperature + top_k: Number of highest probability tokens to keep + top_p: Cumulative probability for nucleus sampling + + Returns: + Tensor: Generated token IDs + """ + batch_size = input_ids.shape[0] + device = input_ids.device + + # Initialize sequence with input_ids + generated = input_ids + + # Cache for key-value pairs + cache = {} + + for _ in range(max_length): + # Get position IDs for current sequence + position_ids = torch.arange( + generated.shape[1], dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).expand( + batch_size, -1 + ) + + # Forward pass + logits, _ = self.forward( + generated, position_ids=position_ids, cache=cache + ) + + # Get next token logits + next_token_logits = logits[:, -1, :] / temperature + + # Apply top-k filtering + if top_k > 0: + indices_to_remove = ( + next_token_logits + < torch.topk(next_token_logits, top_k)[0][ + ..., -1, None + ] + ) + next_token_logits[indices_to_remove] = float("-inf") + + # Apply top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, 
sorted_indices = torch.sort( + next_token_logits, descending=True + ) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = ( + sorted_indices_to_remove[..., :-1].clone() + ) + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[ + sorted_indices_to_remove + ] + next_token_logits[indices_to_remove] = float("-inf") + + # Sample next token + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + # Append next token to sequence + generated = torch.cat((generated, next_token), dim=1) + + # Check for end of sequence token + if (next_token == self.config.vocab_size - 1).all(): + break + + return generated + + +# Initialize model configuration +config = TransformerConfig( + vocab_size=50257, + hidden_size=768, + num_attention_heads=12, + num_expert_layers=4, + num_experts=8, + expert_capacity=32, + max_position_embeddings=1024, + num_query_groups=4, +) + + +def prepare_sample_data( + batch_size: int = 8, + seq_length: int = 512, + vocab_size: int = 50257, +) -> DataLoader: + """Create sample data for demonstration.""" + # Create random input sequences + input_ids = torch.randint( + 0, vocab_size, (100, seq_length) # 100 samples + ) + + # Create target sequences (shifted by 1) + labels = torch.randint(0, vocab_size, (100, seq_length)) + + # Create attention masks (1 for real tokens, 0 for padding) + attention_mask = torch.ones_like(input_ids) + + # Create dataset and dataloader + dataset = TensorDataset(input_ids, attention_mask, labels) + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=True + ) + + return dataloader + + +def train_step( + model: MoETransformer, + batch: tuple, + optimizer: torch.optim.Optimizer, + device: str = "cuda" if torch.cuda.is_available() else "cpu", +) -> float: + 
"""Execute single training step.""" + model.train() + optimizer.zero_grad() + + # Unpack batch + input_ids, attention_mask, labels = [b.to(device) for b in batch] + + # Forward pass + logits, aux_outputs = model( + input_ids=input_ids, attention_mask=attention_mask + ) + + # Calculate loss + loss = model.fetch_loss(logits, labels, aux_outputs) + + # Backward pass + loss.backward() + optimizer.step() + + return loss.item() + + +def main(): + # Set device + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {device}") + + # Initialize model + model = MoETransformer(config).to(device) + logger.info("Model initialized") + + # Setup optimizer + optimizer = torch.optim.AdamW( + model.parameters(), lr=1e-4, weight_decay=0.01 + ) + + # Prepare data + dataloader = prepare_sample_data() + logger.info("Data prepared") + + # Training loop + num_epochs = 3 + for epoch in range(num_epochs): + epoch_losses = [] + + for batch_idx, batch in enumerate(dataloader): + loss = train_step(model, batch, optimizer, device) + epoch_losses.append(loss) + + if batch_idx % 10 == 0: + logger.info( + f"Epoch {epoch+1}/{num_epochs} " + f"Batch {batch_idx}/{len(dataloader)} " + f"Loss: {loss:.4f}" + ) + + avg_loss = np.mean(epoch_losses) + logger.info(f"Epoch {epoch+1} average loss: {avg_loss:.4f}") + + # Generation example + model.eval() + with torch.no_grad(): + # Prepare input prompt + prompt = torch.randint(0, config.vocab_size, (1, 10)).to( + device + ) + + # Generate sequence + generated = model.generate( + input_ids=prompt, + max_length=50, + temperature=0.7, + top_k=50, + top_p=0.9, + ) + + logger.info(f"Generated sequence shape: {generated.shape}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index e5375a0d..13ab894a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,6 @@ pytest>=8.1.1 pandas>=2.2.2 networkx aiofiles -swarm-models clusterops reportlab doc-master diff --git 
a/scripts/platform_update/parse_prompts_and_submit_to_marketplace 2.py b/scripts/platform_update/parse_prompts_and_submit_to_marketplace 2.py deleted file mode 100644 index e8685673..00000000 --- a/scripts/platform_update/parse_prompts_and_submit_to_marketplace 2.py +++ /dev/null @@ -1,121 +0,0 @@ -import json -import os -from difflib import SequenceMatcher - -import requests -from dotenv import load_dotenv -from loguru import logger -from supabase import Client, create_client - -load_dotenv() - -# Initialize Supabase client -SUPABASE_URL = os.getenv("SUPABASE_URL") -SUPABASE_KEY = os.getenv("SUPABASE_KEY") -supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY) - -# Swarms API URL and headers -SWARMS_API_URL = "https://swarms.world/api/add-prompt" -SWARMS_API_KEY = os.getenv("SWARMS_API_KEY") -headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {SWARMS_API_KEY}", -} - -# Configure logger -logger.add( - "fetch_and_publish_prompts.log", rotation="1 MB" -) # Log file with rotation - - -def fetch_and_publish_prompts(): - logger.info("Starting to fetch and publish prompts.") - - # Fetch data from Supabase - try: - response = ( - supabase.table("swarms_framework_schema") - .select("*") - .execute() - ) - rows = response.data - logger.info(f"Fetched {len(rows)} rows from Supabase.") - except Exception as e: - logger.error(f"Failed to fetch data from Supabase: {e}") - return - - # Track published prompts to avoid duplicates - published_prompts = set() - - for row in rows: - # Extract agent_name and system_prompt - data = row.get("data", {}) - agent_name = data.get("agent_name") - system_prompt = data.get("system_prompt") - - # Skip if either is missing or duplicate - if not agent_name or not system_prompt: - logger.warning( - f"Skipping row due to missing agent_name or system_prompt: {row}" - ) - continue - if is_duplicate(system_prompt, published_prompts): - logger.info( - f"Skipping duplicate prompt for agent: {agent_name}" - ) - continue 
- - # Create the data payload for the marketplace - prompt_data = { - "name": f"{agent_name} - System Prompt", - "prompt": system_prompt, - "description": f"System prompt for agent {agent_name}.", - "useCases": extract_use_cases(system_prompt), - "tags": "agent, system-prompt", - } - - # Publish to the marketplace - try: - response = requests.post( - SWARMS_API_URL, - headers=headers, - data=json.dumps(prompt_data), - ) - if response.status_code == 200: - logger.info( - f"Successfully published prompt for agent: {agent_name}" - ) - published_prompts.add(system_prompt) - else: - logger.error( - f"Failed to publish prompt for agent: {agent_name}. Response: {response.text}" - ) - except Exception as e: - logger.error( - f"Exception occurred while publishing prompt for agent: {agent_name}. Error: {e}" - ) - - -def is_duplicate(new_prompt, published_prompts): - """Check if the prompt is a duplicate using semantic similarity.""" - for prompt in published_prompts: - similarity = SequenceMatcher(None, new_prompt, prompt).ratio() - if ( - similarity > 0.9 - ): # Threshold for considering prompts as duplicates - return True - return False - - -def extract_use_cases(prompt): - """Extract use cases from the prompt by chunking it into meaningful segments.""" - # This is a simple placeholder; you can use a more advanced method to extract use cases - chunks = [prompt[i : i + 50] for i in range(0, len(prompt), 50)] - return [ - {"title": f"Use case {idx+1}", "description": chunk} - for idx, chunk in enumerate(chunks) - ] - - -# Main execution -fetch_and_publish_prompts() diff --git a/scripts/platform_update/parse_prompts_and_submit_to_marketplace.py b/scripts/platform_update/parse_prompts_and_submit_to_marketplace.py deleted file mode 100644 index e8685673..00000000 --- a/scripts/platform_update/parse_prompts_and_submit_to_marketplace.py +++ /dev/null @@ -1,121 +0,0 @@ -import json -import os -from difflib import SequenceMatcher - -import requests -from dotenv import load_dotenv 
-from loguru import logger -from supabase import Client, create_client - -load_dotenv() - -# Initialize Supabase client -SUPABASE_URL = os.getenv("SUPABASE_URL") -SUPABASE_KEY = os.getenv("SUPABASE_KEY") -supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY) - -# Swarms API URL and headers -SWARMS_API_URL = "https://swarms.world/api/add-prompt" -SWARMS_API_KEY = os.getenv("SWARMS_API_KEY") -headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {SWARMS_API_KEY}", -} - -# Configure logger -logger.add( - "fetch_and_publish_prompts.log", rotation="1 MB" -) # Log file with rotation - - -def fetch_and_publish_prompts(): - logger.info("Starting to fetch and publish prompts.") - - # Fetch data from Supabase - try: - response = ( - supabase.table("swarms_framework_schema") - .select("*") - .execute() - ) - rows = response.data - logger.info(f"Fetched {len(rows)} rows from Supabase.") - except Exception as e: - logger.error(f"Failed to fetch data from Supabase: {e}") - return - - # Track published prompts to avoid duplicates - published_prompts = set() - - for row in rows: - # Extract agent_name and system_prompt - data = row.get("data", {}) - agent_name = data.get("agent_name") - system_prompt = data.get("system_prompt") - - # Skip if either is missing or duplicate - if not agent_name or not system_prompt: - logger.warning( - f"Skipping row due to missing agent_name or system_prompt: {row}" - ) - continue - if is_duplicate(system_prompt, published_prompts): - logger.info( - f"Skipping duplicate prompt for agent: {agent_name}" - ) - continue - - # Create the data payload for the marketplace - prompt_data = { - "name": f"{agent_name} - System Prompt", - "prompt": system_prompt, - "description": f"System prompt for agent {agent_name}.", - "useCases": extract_use_cases(system_prompt), - "tags": "agent, system-prompt", - } - - # Publish to the marketplace - try: - response = requests.post( - SWARMS_API_URL, - headers=headers, - 
data=json.dumps(prompt_data), - ) - if response.status_code == 200: - logger.info( - f"Successfully published prompt for agent: {agent_name}" - ) - published_prompts.add(system_prompt) - else: - logger.error( - f"Failed to publish prompt for agent: {agent_name}. Response: {response.text}" - ) - except Exception as e: - logger.error( - f"Exception occurred while publishing prompt for agent: {agent_name}. Error: {e}" - ) - - -def is_duplicate(new_prompt, published_prompts): - """Check if the prompt is a duplicate using semantic similarity.""" - for prompt in published_prompts: - similarity = SequenceMatcher(None, new_prompt, prompt).ratio() - if ( - similarity > 0.9 - ): # Threshold for considering prompts as duplicates - return True - return False - - -def extract_use_cases(prompt): - """Extract use cases from the prompt by chunking it into meaningful segments.""" - # This is a simple placeholder; you can use a more advanced method to extract use cases - chunks = [prompt[i : i + 50] for i in range(0, len(prompt), 50)] - return [ - {"title": f"Use case {idx+1}", "description": chunk} - for idx, chunk in enumerate(chunks) - ] - - -# Main execution -fetch_and_publish_prompts() diff --git a/simple_example.py b/simple_example.py index 2fcbb8f9..1044958a 100644 --- a/simple_example.py +++ b/simple_example.py @@ -4,4 +4,6 @@ Agent( agent_name="Stock-Analysis-Agent", model_name="gpt-4o-mini", max_loops=1, + interactive=False, + streaming_on=True, ).run("What are 5 hft algorithms") diff --git a/sol_agent.py b/sol_agent.py new file mode 100644 index 00000000..09319d0e --- /dev/null +++ b/sol_agent.py @@ -0,0 +1,433 @@ +import asyncio +import json +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Dict, List, Optional, Set + +import aiohttp +import matplotlib.pyplot as plt +import networkx as nx +import websockets +from loguru import logger + +from swarms import Agent + +TREND_AGENT_PROMPT = """You are a specialized blockchain trend 
analysis agent. Your role: +1. Analyze transaction patterns in Solana blockchain data +2. Identify volume trends, price movements, and temporal patterns +3. Focus on whale movements and their market impact +4. Format findings in clear, structured JSON +5. Include confidence scores for each insight +6. Flag unusual patterns or anomalies +7. Provide historical context for significant movements + +Output format: +{ + "trends": [ + {"pattern": str, "confidence": float, "impact": str} + ], + "whale_activity": {...}, + "temporal_analysis": {...} +}""" + +RISK_AGENT_PROMPT = """You are a blockchain risk assessment specialist. Your tasks: +1. Identify suspicious transaction patterns +2. Monitor for known exploit signatures +3. Assess wallet clustering and relationship patterns +4. Evaluate transaction velocity and size anomalies +5. Check for bridge-related risks +6. Monitor smart contract interactions +7. Flag potential wash trading + +Output format: +{ + "risk_score": float, + "flags": [...], + "recommendations": [...] +}""" + +SUMMARY_AGENT_PROMPT = """You are a blockchain data synthesis expert. Your responsibilities: +1. Combine insights from trend and risk analyses +2. Prioritize actionable intelligence +3. Highlight critical patterns +4. Generate executive summaries +5. Provide market context +6. Make predictions with confidence intervals +7. 
@dataclass
class Transaction:
    """One transfer extracted from an RPC or websocket payload."""

    signature: str
    timestamp: datetime
    # Balance delta of account 0 divided by 1e9 (logged elsewhere as SOL).
    amount: float
    # accountKeys[0] / accountKeys[1] of the transaction message.
    from_address: str
    to_address: str


class SolanaRPC:
    """Minimal JSON-RPC client for a Solana HTTP endpoint.

    The aiohttp session is created lazily on the first request and must be
    released with :meth:`close` when the client is no longer needed.
    """

    def __init__(
        self, endpoint="https://api.mainnet-beta.solana.com"
    ):
        self.endpoint = endpoint
        self.session = None

    async def _request(self, method: str, params: list):
        """POST one JSON-RPC call and return its ``result`` member (or None)."""
        # Every request path creates the session on demand, so callers may
        # invoke get_transaction() without calling get_signatures() first.
        if self.session is None:
            self.session = aiohttp.ClientSession()

        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": method,
            "params": params,
        }
        async with self.session.post(
            self.endpoint, json=payload
        ) as response:
            body = await response.json()
        return body.get("result")

    async def get_signatures(self, address: str) -> List[Dict]:
        """Return up to 100 signature records for ``address``.

        An error response or a null result yields an empty list.
        """
        result = await self._request(
            "getSignaturesForAddress", [address, {"limit": 100}]
        )
        return result or []

    async def get_transaction(self, signature: str) -> Dict:
        """Return the transaction for ``signature``, or ``{}`` if unavailable."""
        result = await self._request(
            "getTransaction",
            [
                signature,
                {
                    "encoding": "json",
                    "maxSupportedTransactionVersion": 0,
                },
            ],
        )
        return result or {}

    async def close(self):
        """Close the underlying HTTP session, if one was opened."""
        if self.session is not None:
            await self.session.close()
            self.session = None


class AlertSystem:
    """Decides whether a transaction warrants an alert and records it."""

    def __init__(self, email: str, threshold: float = 1000.0):
        self.email = email
        self.threshold = threshold
        self.smtp_server = "smtp.gmail.com"
        self.smtp_port = 587

    async def check_and_alert(
        self, transaction: Transaction, risk_score: float
    ):
        """Alert when the amount exceeds the threshold or risk is above 0.8."""
        if transaction.amount > self.threshold or risk_score > 0.8:
            await self.send_alert(transaction, risk_score)

    async def send_alert(
        self, transaction: Transaction, risk_score: float
    ):
        """Record an alert for ``transaction``.

        NOTE: no e-mail is actually delivered; the alert is only written to
        the log. ``email``, ``smtp_server`` and ``smtp_port`` are stored for a
        future SMTP implementation.
        """
        logger.info(
            f"Alert sent for transaction {transaction.signature}"
        )


class WalletClusterAnalyzer:
    """Builds a wallet-to-wallet graph and derives clusters from it."""

    def __init__(self):
        self.graph = nx.Graph()
        self.known_wallets: Set[str] = set()

    def update_graph(self, transaction: Transaction):
        """Add (or overwrite) the edge between sender and receiver."""
        self.graph.add_edge(
            transaction.from_address,
            transaction.to_address,
            weight=transaction.amount,
        )
        self.known_wallets.add(transaction.from_address)
        self.known_wallets.add(transaction.to_address)

    def identify_clusters(self) -> Dict:
        """Return detected communities and wallets with degree above 5."""
        # Nothing has been recorded yet: avoid handing an empty graph to the
        # community detector and report an empty result instead.
        if self.graph.number_of_nodes() == 0:
            return {"clusters": [], "central_wallets": []}

        communities = nx.community.greedy_modularity_communities(
            self.graph
        )
        return {
            "clusters": [list(c) for c in communities],
            "central_wallets": [
                wallet
                for wallet in self.known_wallets
                if self.graph.degree[wallet] > 5
            ],
        }


class TransactionVisualizer:
    """Accumulates transactions and renders them to PNG files."""

    def __init__(self):
        self.transaction_history = []

    def add_transaction(self, transaction: Transaction):
        """Append the transaction (as a plain dict) to the history."""
        self.transaction_history.append(asdict(transaction))

    def generate_volume_chart(self) -> str:
        """Plot recorded amounts in arrival order; return the file name."""
        volumes = [tx["amount"] for tx in self.transaction_history]
        fig = plt.figure(figsize=(12, 6))
        try:
            plt.plot(volumes)
            plt.title("Transaction Volume Over Time")
            plt.savefig("volume_chart.png")
        finally:
            # This runs inside a periodic loop; figures that are never
            # closed stay registered with pyplot and accumulate.
            plt.close(fig)
        return "volume_chart.png"

    def generate_network_graph(
        self, wallet_analyzer: WalletClusterAnalyzer
    ) -> str:
        """Draw the wallet graph with a spring layout; return the file name."""
        fig = plt.figure(figsize=(15, 15))
        try:
            pos = nx.spring_layout(wallet_analyzer.graph)
            nx.draw(
                wallet_analyzer.graph,
                pos,
                node_size=1000,
                node_color="lightblue",
                with_labels=True,
            )
            plt.savefig("network_graph.png")
        finally:
            plt.close(fig)
        return "network_graph.png"


class SolanaMultiAgentAnalyzer:
    """Streams and polls transactions and analyses them with three agents."""

    def __init__(
        self,
        min_amount: float = 50.0,
        websocket_url: str = "wss://api.mainnet-beta.solana.com",
        alert_email: Optional[str] = None,
    ):
        self.rpc = SolanaRPC()
        self.websocket_url = websocket_url
        self.min_amount = min_amount
        self.transactions = []
        # Handle of the background websocket task (set when analysis starts).
        self._stream_task = None

        self.wallet_analyzer = WalletClusterAnalyzer()
        self.visualizer = TransactionVisualizer()
        self.alert_system = (
            AlertSystem(alert_email) if alert_email else None
        )

        self.trend_agent = Agent(
            agent_name="trend-analyzer",
            system_prompt=TREND_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.risk_agent = Agent(
            agent_name="risk-analyzer",
            system_prompt=RISK_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        self.summary_agent = Agent(
            agent_name="summary-agent",
            system_prompt=SUMMARY_AGENT_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        logger.add(
            "solana_analysis.log", rotation="500 MB", level="INFO"
        )

    @staticmethod
    async def _run_agent(agent, prompt: str):
        """Run ``agent`` on ``prompt``, awaiting the result only if needed."""
        # NOTE(review): Agent.run appears to be a regular blocking call;
        # awaiting a non-awaitable return value raises TypeError, so only
        # await when the call actually produced a coroutine/future.
        result = agent.run(prompt)
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        return result

    @staticmethod
    def _extract_risk_score(risk_analysis) -> float:
        """Best-effort extraction of ``risk_score`` from an agent reply.

        Accepts a dict or a JSON string encoding one; anything else, or a
        missing/non-numeric score, yields 0.0.
        """
        if isinstance(risk_analysis, str):
            try:
                risk_analysis = json.loads(risk_analysis)
            except ValueError:
                return 0.0
        if not isinstance(risk_analysis, dict):
            return 0.0
        try:
            return float(risk_analysis.get("risk_score", 0))
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _build_transaction(signature: str, tx_data: Dict) -> Transaction:
        """Convert a raw transaction payload into a ``Transaction``.

        Raises KeyError/IndexError/TypeError when the payload lacks the
        expected fields (for example a null ``meta`` or ``blockTime``).
        """
        meta = tx_data["meta"]
        account_keys = tx_data["transaction"]["message"]["accountKeys"]
        # Absolute value: account 0's balance drops when it sends funds, and
        # a signed delta would never pass the ``>= min_amount`` filter.
        amount = (
            abs(meta["preBalances"][0] - meta["postBalances"][0]) / 1e9
        )
        return Transaction(
            signature=signature,
            timestamp=datetime.fromtimestamp(tx_data["blockTime"]),
            amount=amount,
            from_address=account_keys[0],
            to_address=account_keys[1],
        )

    async def start_websocket_stream(self):
        """Subscribe to the websocket feed and process messages forever.

        A connection or receive failure is logged and followed by a
        reconnect after five seconds, rather than retrying ``recv`` on a
        socket that is already closed.
        """
        subscribe_message = json.dumps(
            {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "programSubscribe",
                "params": [
                    "11111111111111111111111111111111",
                    {"encoding": "json", "commitment": "confirmed"},
                ],
            }
        )

        while True:
            try:
                async with websockets.connect(
                    self.websocket_url
                ) as websocket:
                    await websocket.send(subscribe_message)
                    while True:
                        msg = await websocket.recv()
                        await self._handle_stream_message(msg)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Websocket error: {e}")
                await asyncio.sleep(5)

    async def _handle_stream_message(self, msg: str):
        """Parse one websocket message and process it if it is large enough."""
        transaction = await self.parse_websocket_message(msg)
        if not transaction or transaction.amount < self.min_amount:
            return
        try:
            await self.process_transaction(transaction)
        except Exception as e:
            # A processing failure must not tear down a healthy connection.
            logger.error(
                f"Error processing transaction {transaction.signature}: {e}"
            )

    async def parse_websocket_message(
        self, msg: str
    ) -> Optional[Transaction]:
        """Return the ``Transaction`` carried by ``msg``, or None.

        None is returned both for messages without ``params.result`` and
        for payloads that cannot be parsed (the error is logged).
        """
        try:
            data = json.loads(msg)
            if "params" in data and "result" in data["params"]:
                tx_data = data["params"]["result"]
                return self._build_transaction(
                    tx_data["signature"], tx_data
                )
        except Exception as e:
            # Deliberately broad: this is untrusted network input and a bad
            # message should never stop the stream.
            logger.error(f"Error parsing websocket message: {e}")
        return None

    async def process_transaction(self, transaction: Transaction):
        """Record a streamed transaction, score its risk and maybe alert."""
        self.wallet_analyzer.update_graph(transaction)
        self.visualizer.add_transaction(transaction)

        # default=str: the dataclass carries a datetime, which json cannot
        # serialise on its own.
        payload = json.dumps(asdict(transaction), default=str)
        risk_analysis = await self._run_agent(
            self.risk_agent,
            f"Analyze risk for transaction: {payload}",
        )

        if self.alert_system:
            await self.alert_system.check_and_alert(
                transaction, self._extract_risk_score(risk_analysis)
            )

    async def fetch_transactions(self) -> List[Transaction]:
        """Poll recent signatures and return transfers >= ``min_amount``.

        Malformed individual transactions are skipped; a failure of the RPC
        layer itself is logged and produces an empty list.
        """
        try:
            signatures = await self.rpc.get_signatures(
                "11111111111111111111111111111111"
            )
            transactions = []

            for sig_info in signatures:
                signature = sig_info["signature"]
                tx_data = await self.rpc.get_transaction(signature)
                if not tx_data or "meta" not in tx_data:
                    continue

                try:
                    tx = self._build_transaction(signature, tx_data)
                except (KeyError, IndexError, TypeError, ValueError) as e:
                    # One bad record should not discard the whole batch.
                    logger.warning(
                        f"Skipping malformed transaction {signature}: {e}"
                    )
                    continue

                if tx.amount >= self.min_amount:
                    transactions.append(tx)

            return transactions
        except Exception as e:
            logger.error(f"Error fetching transactions: {e}")
            return []

    async def analyze_transactions(
        self, transactions: List[Transaction]
    ) -> Dict:
        """Run trend, risk and summary agents and render both charts.

        NOTE(review): clusters and charts are built from what the websocket
        path recorded, not from ``transactions`` — confirm this is intended.
        """
        tx_data = [asdict(tx) for tx in transactions]
        cluster_data = self.wallet_analyzer.identify_clusters()

        trend_analysis = await self._run_agent(
            self.trend_agent,
            f"Analyze trends in: {json.dumps(tx_data, default=str)}",
        )
        print(trend_analysis)

        risk_payload = json.dumps(
            {"transactions": tx_data, "clusters": cluster_data},
            default=str,
        )
        risk_analysis = await self._run_agent(
            self.risk_agent, f"Analyze risks in: {risk_payload}"
        )
        print(risk_analysis)

        summary = await self._run_agent(
            self.summary_agent,
            f"Synthesize insights from: {trend_analysis}, {risk_analysis}",
        )

        print(summary)

        volume_chart = self.visualizer.generate_volume_chart()
        network_graph = self.visualizer.generate_network_graph(
            self.wallet_analyzer
        )

        return {
            "transactions": tx_data,
            "trend_analysis": trend_analysis,
            "risk_analysis": risk_analysis,
            "cluster_analysis": cluster_data,
            "summary": summary,
            "visualizations": {
                "volume_chart": volume_chart,
                "network_graph": network_graph,
            },
        }

    async def run_continuous_analysis(self):
        """Start the websocket stream and analyse fetched data every 60 s."""
        logger.info("Starting continuous analysis")
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._stream_task = asyncio.create_task(
            self.start_websocket_stream()
        )

        try:
            while True:
                try:
                    transactions = await self.fetch_transactions()
                    if transactions:
                        analysis = await self.analyze_transactions(
                            transactions
                        )
                        timestamp = datetime.now().strftime(
                            "%Y%m%d_%H%M%S"
                        )
                        with open(
                            f"analysis_{timestamp}.json",
                            "w",
                            encoding="utf-8",
                        ) as f:
                            json.dump(
                                analysis, f, indent=2, default=str
                            )
                        logger.info(
                            f"Analysis completed: analysis_{timestamp}.json"
                        )
                except Exception as e:
                    logger.error(f"Error in analysis loop: {e}")
                await asyncio.sleep(60)
        finally:
            # Reached on cancellation/shutdown: stop the stream and release
            # the HTTP session.
            self._stream_task.cancel()
            await self.rpc.close()


# Script entry point: run the analyzer until interrupted.
if __name__ == "__main__":
    logger.info("Starting Solana analyzer...")
    analyzer = SolanaMultiAgentAnalyzer(alert_email="your@email.com")
    try:
        asyncio.run(analyzer.run_continuous_analysis())
    except KeyboardInterrupt:
        logger.info("Solana analyzer stopped by user")
    except Exception as e:
        logger.error(f"Critical error: {e}")
diff --git a/swarms/cli/create_agent.py b/swarms/cli/create_agent.py index 3caaed80..0f536da6 100644 --- a/swarms/cli/create_agent.py +++ b/swarms/cli/create_agent.py @@ -1,16 +1,6 @@ -import os from swarms.structs.agent import Agent -from swarm_models.popular_llms import OpenAIChat from swarms.structs.agent_registry import AgentRegistry -# Get the OpenAI API key from the environment variable -api_key = os.getenv("OPENAI_API_KEY") - -# Create an instance of the OpenAIChat class -model = OpenAIChat( - api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 -) - # Registry of agents agent_registry = AgentRegistry( @@ -19,7 +9,12 @@ agent_registry = AgentRegistry( ) -def create_agent(name: str, system_prompt: str, max_loops: int = 1): +def create_agent( + name: str, + system_prompt: str, + max_loops: int = 1, + model_name: str = "gpt-4o", +): """ Create and initialize an agent with the given parameters. @@ -36,7 +31,7 @@ def create_agent(name: str, system_prompt: str, max_loops: int = 1): agent = Agent( agent_name=name, system_prompt=system_prompt, - llm=model, + model_name=model_name, max_loops=max_loops, autosave=True, dashboard=False, diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index adb33324..5afc1159 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -19,7 +19,6 @@ from swarms.structs.majority_voting import ( most_frequent, parse_code_completion, ) -from swarms.structs.message import Message from swarms.structs.mixture_of_agents import MixtureOfAgents from swarms.structs.multi_agent_collab import MultiAgentCollaboration from swarms.structs.multi_agent_exec import ( @@ -39,7 +38,6 @@ from swarms.structs.round_robin import RoundRobinSwarm from swarms.structs.sequential_workflow import SequentialWorkflow from swarms.structs.spreadsheet_swarm import SpreadSheetSwarm from swarms.structs.swarm_arange import SwarmRearrange -from swarms.structs.swarm_net import SwarmNetwork from swarms.structs.swarm_router import ( 
SwarmRouter, SwarmType, @@ -90,9 +88,7 @@ __all__ = [ "majority_voting", "most_frequent", "parse_code_completion", - "Message", "MultiAgentCollaboration", - "SwarmNetwork", "AgentRearrange", "rearrange", "RoundRobinSwarm", diff --git a/swarms/structs/agent_router.py b/swarms/structs/agent_router.py index 6cf3c094..a03aa84b 100644 --- a/swarms/structs/agent_router.py +++ b/swarms/structs/agent_router.py @@ -1,14 +1,19 @@ from typing import List, Optional -import chromadb from tenacity import retry, stop_after_attempt, wait_exponential from typing import Union, Callable, Any from swarms import Agent from swarms.utils.loguru_logger import initialize_logger +from swarms.utils.lazy_loader import lazy_import_decorator +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) + logger = initialize_logger(log_folder="agent_router") +@lazy_import_decorator class AgentRouter: """ Initialize the AgentRouter. @@ -29,6 +34,14 @@ class AgentRouter: *args, **kwargs, ): + try: + import chromadb + except ImportError: + auto_check_and_download_package( + "chromadb", package_manager="pip", upgrade=True + ) + import chromadb + self.collection_name = collection_name self.n_agents = n_agents self.persist_directory = persist_directory diff --git a/swarms/structs/base_swarm.py b/swarms/structs/base_swarm.py index 6e2242be..29dcccbf 100644 --- a/swarms/structs/base_swarm.py +++ b/swarms/structs/base_swarm.py @@ -16,7 +16,6 @@ from typing import ( import yaml -from swarms_memory import BaseVectorDatabase from swarms.structs.agent import Agent from swarms.structs.conversation import Conversation from swarms.structs.omni_agent_types import AgentType @@ -98,9 +97,7 @@ class BaseSwarm(ABC): agentops_on: Optional[bool] = False, speaker_selection_func: Optional[Callable] = None, rules: Optional[str] = None, - collective_memory_system: Optional[ - BaseVectorDatabase - ] = False, + collective_memory_system: Optional[Any] = False, agent_ops_on: bool = False, 
output_schema: Optional[BaseModel] = None, *args, diff --git a/swarms/structs/graph_swarm.py b/swarms/structs/graph_swarm.py index 82cef523..a96379e2 100644 --- a/swarms/structs/graph_swarm.py +++ b/swarms/structs/graph_swarm.py @@ -1,10 +1,3 @@ -""" -GraphSwarm: A production-grade framework for orchestrating swarms of agents -Author: Claude -License: MIT -Version: 2.0.0 -""" - import asyncio import json import time @@ -12,13 +5,13 @@ from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union -import chromadb import networkx as nx from loguru import logger from pydantic import BaseModel, Field - -from swarms import Agent - +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) +from swarms.structs.agent import Agent # Configure logging logger.add( @@ -57,6 +50,15 @@ class SwarmMemory: def __init__(self, collection_name: str = "swarm_memories"): """Initialize SwarmMemory with ChromaDB.""" + + try: + import chromadb + except ImportError: + auto_check_and_download_package( + "chromadb", package_manager="pip", upgrade=True + ) + import chromadb + self.client = chromadb.Client() # Get or create collection diff --git a/swarms/structs/groupchat_new.py b/swarms/structs/groupchat_new.py index 69c424d4..a6aaaa7c 100644 --- a/swarms/structs/groupchat_new.py +++ b/swarms/structs/groupchat_new.py @@ -3,7 +3,6 @@ import asyncio from pydantic import BaseModel, Field from typing import List, Dict, Any from swarms import Agent -from swarm_models import OpenAIChat from dotenv import load_dotenv from swarms.utils.formatter import formatter @@ -181,64 +180,64 @@ class GroupChat: ] -# Example Usage -if __name__ == "__main__": - - load_dotenv() - - # Get the OpenAI API key from the environment variable - api_key = os.getenv("OPENAI_API_KEY") - - # Create an instance of the OpenAIChat class - model = OpenAIChat( - openai_api_key=api_key, - model_name="gpt-4o-mini", - 
temperature=0.1, - ) - - # Example agents - agent1 = Agent( - agent_name="Financial-Analysis-Agent", - system_prompt="You are a financial analyst specializing in investment strategies.", - llm=model, - max_loops=1, - autosave=False, - dashboard=False, - verbose=True, - dynamic_temperature_enabled=True, - user_name="swarms_corp", - retry_attempts=1, - context_length=200000, - output_type="string", - streaming_on=False, - ) - - agent2 = Agent( - agent_name="Tax-Adviser-Agent", - system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.", - llm=model, - max_loops=1, - autosave=False, - dashboard=False, - verbose=True, - dynamic_temperature_enabled=True, - user_name="swarms_corp", - retry_attempts=1, - context_length=200000, - output_type="string", - streaming_on=False, - ) - - # Create group chat - group_chat = GroupChat( - name="Financial Discussion", - description="A group chat for financial analysis and tax advice.", - agents=[agent1, agent2], - ) - - # Run the group chat - asyncio.run( - group_chat.run( - "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?" 
- ) - ) +# # Example Usage +# if __name__ == "__main__": + +# load_dotenv() + +# # Get the OpenAI API key from the environment variable +# api_key = os.getenv("OPENAI_API_KEY") + +# # Create an instance of the OpenAIChat class +# model = OpenAIChat( +# openai_api_key=api_key, +# model_name="gpt-4o-mini", +# temperature=0.1, +# ) + +# # Example agents +# agent1 = Agent( +# agent_name="Financial-Analysis-Agent", +# system_prompt="You are a financial analyst specializing in investment strategies.", +# llm=model, +# max_loops=1, +# autosave=False, +# dashboard=False, +# verbose=True, +# dynamic_temperature_enabled=True, +# user_name="swarms_corp", +# retry_attempts=1, +# context_length=200000, +# output_type="string", +# streaming_on=False, +# ) + +# agent2 = Agent( +# agent_name="Tax-Adviser-Agent", +# system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.", +# llm=model, +# max_loops=1, +# autosave=False, +# dashboard=False, +# verbose=True, +# dynamic_temperature_enabled=True, +# user_name="swarms_corp", +# retry_attempts=1, +# context_length=200000, +# output_type="string", +# streaming_on=False, +# ) + +# # Create group chat +# group_chat = GroupChat( +# name="Financial Discussion", +# description="A group chat for financial analysis and tax advice.", +# agents=[agent1, agent2], +# ) + +# # Run the group chat +# asyncio.run( +# group_chat.run( +# "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?" +# ) +# ) diff --git a/swarms/structs/message.py b/swarms/structs/message.py deleted file mode 100644 index ae686790..00000000 --- a/swarms/structs/message.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Dict, Optional -from datetime import datetime -from pydantic import BaseModel, Field - - -class Message(BaseModel): - """ - Represents a message with timestamp and optional metadata. 
- - Usage - -------------- - mes = Message( - sender = "Kye", - content = "message" - ) - - print(mes) - """ - - timestamp: datetime = Field(default_factory=datetime.now) - sender: str - content: str - metadata: Optional[Dict[str, str]] = {} - - def __repr__(self) -> str: - """ - __repr__ means... - """ - return f"{self.timestamp} - {self.sender}: {self.content}" diff --git a/swarms/structs/multi_agent_exec.py b/swarms/structs/multi_agent_exec.py index 839e9e45..ef87a5d8 100644 --- a/swarms/structs/multi_agent_exec.py +++ b/swarms/structs/multi_agent_exec.py @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor import psutil from dataclasses import dataclass import threading -from typing import List, Union, Any, Callable +from typing import List, Any from multiprocessing import cpu_count import os diff --git a/swarms/structs/output_types.py b/swarms/structs/output_types.py new file mode 100644 index 00000000..7e4a4644 --- /dev/null +++ b/swarms/structs/output_types.py @@ -0,0 +1,15 @@ +from typing import Literal + +# Literal of output types +# Literal of output types +OutputType = Literal[ + "all", + "final", + "list", + "dict", + ".json", + ".md", + ".txt", + ".yaml", + ".toml", +] diff --git a/swarms/structs/rearrange.py b/swarms/structs/rearrange.py index 801861b0..8fc4ecca 100644 --- a/swarms/structs/rearrange.py +++ b/swarms/structs/rearrange.py @@ -3,10 +3,9 @@ import traceback import uuid from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from typing import Callable, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import BaseModel, Field -from swarms_memory import BaseVectorDatabase from swarms.schemas.agent_step_schemas import ManySteps from swarms.structs.agent import Agent @@ -17,22 +16,10 @@ from swarms.utils.loguru_logger import initialize_logger from swarms.utils.wrapper_clusterop import ( exec_callable_with_clusterops, ) +from swarms.structs.output_types 
import OutputType logger = initialize_logger(log_folder="rearrange") -# Literal of output types -OutputType = Literal[ - "all", - "final", - "list", - "dict", - ".json", - ".md", - ".txt", - ".yaml", - ".toml", -] - def swarm_id(): return uuid.uuid4().hex @@ -112,7 +99,7 @@ class AgentRearrange(BaseSwarm): flow: str = None, max_loops: int = 1, verbose: bool = True, - memory_system: BaseVectorDatabase = None, + memory_system: Any = None, human_in_the_loop: bool = False, custom_human_in_the_loop: Optional[ Callable[[str], str] diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index ed55102d..61cdbb0e 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -1,6 +1,7 @@ from typing import List, Optional from swarms.structs.agent import Agent -from swarms.structs.rearrange import AgentRearrange, OutputType +from swarms.structs.rearrange import AgentRearrange +from swarms.structs.output_types import OutputType from concurrent.futures import ThreadPoolExecutor, as_completed from swarms.utils.loguru_logger import initialize_logger diff --git a/swarms/structs/swarm_matcher.py b/swarms/structs/swarm_matcher.py index c4d0711f..21b973a7 100644 --- a/swarms/structs/swarm_matcher.py +++ b/swarms/structs/swarm_matcher.py @@ -1,11 +1,14 @@ from typing import List, Tuple, Optional import numpy as np -import torch -from transformers import AutoTokenizer, AutoModel +from swarms.utils.lazy_loader import lazy_import_decorator from pydantic import BaseModel, Field import json from tenacity import retry, stop_after_attempt, wait_exponential from swarms.utils.loguru_logger import initialize_logger +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) + logger = initialize_logger(log_folder="swarm_matcher") @@ -25,6 +28,7 @@ class SwarmMatcherConfig(BaseModel): ) +@lazy_import_decorator class SwarmMatcher: """ A class for matching tasks to swarm types based on their 
descriptions. @@ -41,12 +45,34 @@ class SwarmMatcher: """ logger.add("swarm_matcher_debug.log", level="DEBUG") logger.debug("Initializing SwarmMatcher") + + try: + import torch + except ImportError: + auto_check_and_download_package( + "torch", package_manager="pip", upgrade=True + ) + import torch + + try: + import transformers + except ImportError: + auto_check_and_download_package( + "transformers", package_manager="pip", upgrade=True + ) + import transformers + + self.torch = torch try: self.config = config - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer = ( + transformers.AutoTokenizer.from_pretrained( + config.model_name + ) + ) + self.model = transformers.AutoModel.from_pretrained( config.model_name ) - self.model = AutoModel.from_pretrained(config.model_name) self.swarm_types: List[SwarmType] = [] logger.debug("SwarmMatcher initialized successfully") except Exception as e: @@ -76,7 +102,7 @@ class SwarmMatcher: truncation=True, max_length=512, ) - with torch.no_grad(): + with self.torch.no_grad(): outputs = self.model(**inputs) embedding = ( outputs.last_hidden_state.mean(dim=1) @@ -244,6 +270,7 @@ def initialize_swarm_types(matcher: SwarmMatcher): logger.debug("Swarm types initialized") +@lazy_import_decorator def swarm_matcher(task: str, *args, **kwargs): """ Runs the SwarmMatcher example with predefined tasks and swarm types. 
diff --git a/swarms/structs/swarm_net.py b/swarms/structs/swarm_net.py deleted file mode 100644 index dac0d0a2..00000000 --- a/swarms/structs/swarm_net.py +++ /dev/null @@ -1,511 +0,0 @@ -""" -Todo -- [ ] Test the new api feature -- [ ] Add the agent schema for every agent -- following OpenAI assistaants schema -- [ ] then add the swarm schema for the swarm url: /v1/swarms/{swarm_name}/agents/{agent_id} -- [ ] Add the agent schema for the agent url: /v1/swarms/{swarm_name}/agents/{agent_id} -""" - -import asyncio -import multiprocessing -import queue -import threading -from typing import List, Optional - -import tenacity - -# from fastapi import FastAPI -from pydantic import BaseModel - -from swarms.structs.agent import Agent -from swarms.structs.base_swarm import BaseSwarm -from swarms.utils.loguru_logger import initialize_logger - -logger = initialize_logger("swarm-network") - - -# Pydantic models -class TaskRequest(BaseModel): - task: str - - -# Pydantic models -class TaskResponse(BaseModel): - result: str - - -class AgentInfo(BaseModel): - agent_name: str - agent_description: str - - -class SwarmInfo(BaseModel): - swarm_name: str - swarm_description: str - agents: List[AgentInfo] - - -# Helper function to get the number of workers -def get_number_of_workers(): - return multiprocessing.cpu_count() - - -# [TODO] Add the agent schema for every agent -- following OpenAI assistaants schema -class SwarmNetwork(BaseSwarm): - """ - SwarmNetwork class - - The SwarmNetwork class is responsible for managing the agents pool - and the task queue. It also monitors the health of the agents and - scales the pool up or down based on the number of pending tasks - and the current load of the agents. - - For example, if the number of pending tasks is greater than the - number of agents in the pool, the SwarmNetwork will scale up the - pool by adding new agents. 
If the number of pending tasks is less - than the number of agents in the pool, the SwarmNetwork will scale - down the pool by removing agents. - - The SwarmNetwork class also provides a simple API for interacting - with the agents pool. The API is implemented using the Flask - framework and is enabled by default. The API can be disabled by - setting the `api_enabled` parameter to False. - - Features: - - Agent pool management - - Task queue management - - Agent health monitoring - - Agent pool scaling - - Simple API for interacting with the agent pool - - Simple API for interacting with the task queue - - Simple API for interacting with the agent health monitor - - Simple API for interacting with the agent pool scaler - - Create APIs for each agent in the pool (optional) - - Run each agent on it's own thread - - Run each agent on it's own process - - Run each agent on it's own container - - Run each agent on it's own machine - - Run each agent on it's own cluster - - - Attributes: - task_queue (queue.Queue): A queue for storing tasks. - idle_threshold (float): The idle threshold for the agents. - busy_threshold (float): The busy threshold for the agents. - agents (List[Agent]): A list of agents in the pool. - api_enabled (bool): A flag to enable/disable the API. - logging_enabled (bool): A flag to enable/disable logging. 
- - Example: - >>> from swarms.structs.agent import Agent - >>> from swarms.structs.swarm_net import SwarmNetwork - >>> agent = Agent() - >>> swarm = SwarmNetwork(agents=[agent]) - >>> swarm.add_task("task") - >>> swarm.run() - - """ - - def __init__( - self, - name: str = None, - description: str = None, - agents: List[Agent] = None, - idle_threshold: float = 0.2, - busy_threshold: float = 0.7, - api_enabled: Optional[bool] = False, - logging_enabled: Optional[bool] = False, - api_on: Optional[bool] = False, - host: str = "0.0.0.0", - port: int = 8000, - swarm_callable: Optional[callable] = None, - *args, - **kwargs, - ): - super().__init__(agents=agents, *args, **kwargs) - self.name = name - self.description = description - self.agents = agents - self.task_queue = queue.Queue() - self.idle_threshold = idle_threshold - self.busy_threshold = busy_threshold - self.lock = threading.Lock() - self.api_enabled = api_enabled - self.logging_enabled = logging_enabled - self.host = host - self.port = port - self.swarm_callable = swarm_callable - - # Ensure that the agents list is not empty - if not agents: - raise ValueError("The agents list cannot be empty") - - # Create a dictionary of agents for easy access - self.agent_dict = {agent.id: agent for agent in agents} - - # # Create the FastAPI instance - # if api_on is True: - # logger.info("Creating FastAPI instance") - # self.app = FastAPI(debug=True, *args, **kwargs) - - # self.app.add_middleware( - # CORSMiddleware, - # allow_origins=["*"], - # allow_credentials=True, - # allow_methods=["*"], - # allow_headers=["*"], - # ) - - # logger.info("Routes set for creation") - # self._create_routes() - - def add_task(self, task): - """Add task to the task queue - - Args: - task (_type_): _description_ - - Example: - >>> from swarms.structs.agent import Agent - >>> from swarms.structs.swarm_net import SwarmNetwork - >>> agent = Agent() - >>> swarm = SwarmNetwork(agents=[agent]) - >>> swarm.add_task("task") - """ - 
self.logger.info(f"Adding task {task} to queue") - try: - self.task_queue.put(task) - self.logger.info(f"Task {task} added to queue") - except Exception as error: - print( - f"Error adding task to queue: {error} try again with" - " a new task" - ) - raise error - - async def async_add_task(self, task): - """Add task to the task queue - - Args: - task (_type_): _description_ - - Example: - >>> from swarms.structs.agent import Agent - >>> from swarms.structs.swarm_net import SwarmNetwork - >>> agent = Agent() - >>> swarm = SwarmNetwork(agents=[agent]) - >>> swarm.add_task("task") - - """ - self.logger.info( - f"Adding task {task} to queue asynchronously" - ) - try: - # Add task to queue asynchronously with asyncio - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.task_queue.put, task - ) - self.logger.info(f"Task {task} added to queue") - except Exception as error: - print( - f"Error adding task to queue: {error} try again with" - " a new task" - ) - raise error - - # def _create_routes(self) -> None: - # """ - # Creates the routes for the API. 
- # """ - # # Extensive logginbg - # logger.info("Creating routes for the API") - - # # Routes available - # logger.info( - # "Routes available: /v1/swarms, /v1/health, /v1/swarms/{swarm_name}/agents/{agent_id}, /v1/swarms/{swarm_name}/run" - # ) - - # @self.app.get("/v1/swarms", response_model=SwarmInfo) - # async def get_swarms() -> SwarmInfo: - # try: - # logger.info("Getting swarm information") - # return SwarmInfo( - # swarm_name=self.swarm_name, - # swarm_description=self.swarm_description, - # agents=[ - # AgentInfo( - # agent_name=agent.agent_name, - # agent_description=agent.agent_description, - # ) - # for agent in self.agents - # ], - # ) - # except Exception as e: - # logger.error(f"Error getting swarm information: {str(e)}") - # raise HTTPException( - # status_code=500, detail="Internal Server Error" - # ) - - # @self.app.get("/v1/health") - # async def get_health() -> Dict[str, str]: - # try: - # logger.info("Checking health status") - # return {"status": "healthy"} - # except Exception as e: - # logger.error(f"Error checking health status: {str(e)}") - # raise HTTPException( - # status_code=500, detail="Internal Server Error" - # ) - - # @self.app.get(f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}") - # async def get_agent_info(agent_id: str) -> AgentInfo: - # try: - # logger.info(f"Getting information for agent {agent_id}") - # agent = self.agent_dict.get(agent_id) - # if not agent: - # raise HTTPException( - # status_code=404, detail="Agent not found" - # ) - # return AgentInfo( - # agent_name=agent.agent_name, - # agent_description=agent.agent_description, - # ) - # except Exception as e: - # logger.error(f"Error getting agent information: {str(e)}") - # raise HTTPException( - # status_code=500, detail="Internal Server Error" - # ) - - # @self.app.post( - # f"/v1/swarms/{self.swarm_name}/agents/{{agent_id}}/run", - # response_model=TaskResponse, - # ) - # async def run_agent_task( - # task_request: TaskRequest, - # ) -> TaskResponse: - # try: 
- # logger.info("Running agent task") - # # Assuming only one agent in the swarm for this example - # agent = self.agents[0] - # logger.info(f"Running agent task: {task_request.task}") - # result = agent.run(task_request.task) - # return TaskResponse(result=result) - # except Exception as e: - # logger.error(f"Error running agent task: {str(e)}") - # raise HTTPException( - # status_code=500, detail="Internal Server Error" - # ) - - # def get_app(self) -> FastAPI: - # """ - # Returns the FastAPI instance. - - # Returns: - # FastAPI: The FastAPI instance. - # """ - # return self.app - - def run_single_agent( - self, agent_id, task: Optional[str], *args, **kwargs - ): - """Run agent the task on the agent id - - Args: - agent_id (_type_): _description_ - task (str, optional): _description_. Defaults to None. - - Raises: - ValueError: _description_ - - Returns: - _type_: _description_ - """ - self.logger.info(f"Running task {task} on agent {agent_id}") - try: - for agent in self.agents: - if agent.id == agent_id: - out = agent.run(task, *args, **kwargs) - return out - except Exception as error: - self.logger.error(f"Error running task on agent: {error}") - raise error - - def run_many_agents( - self, task: Optional[str] = None, *args, **kwargs - ) -> List: - """Run the task on all agents - - Args: - task (str, optional): _description_. Defaults to None. 
- - Returns: - List: _description_ - """ - self.logger.info(f"Running task {task} on all agents") - try: - return [ - agent.run(task, *args, **kwargs) - for agent in self.agents - ] - except Exception as error: - logger.error(f"Error running task on agents: {error}") - raise error - - def list_agents(self): - """List all agents.""" - self.logger.info("[Listing all active agents]") - - try: - # Assuming self.agents is a list of agent objects - for agent in self.agents: - self.logger.info( - f"[Agent] [ID: {agent.id}] [Name:" - f" {agent.agent_name}] [Description:" - f" {agent.agent_description}] [Status: Running]" - ) - except Exception as error: - self.logger.error(f"Error listing agents: {error}") - raise - - def get_agent(self, agent_id): - """Get agent by id - - Args: - agent_id (_type_): _description_ - - Returns: - _type_: _description_ - """ - self.logger.info(f"Getting agent {agent_id}") - - try: - for agent in self.agents: - if agent.id == agent_id: - return agent - raise ValueError(f"No agent found with ID {agent_id}") - except Exception as error: - self.logger.error(f"Error getting agent: {error}") - raise error - - def add_agent(self, agent: Agent): - """Add agent to the agent pool - - Args: - agent (_type_): _description_ - """ - self.logger.info(f"Adding agent {agent} to pool") - try: - self.agents.append(agent) - except Exception as error: - print(f"Error adding agent to pool: {error}") - raise error - - def remove_agent(self, agent_id): - """Remove agent from the agent pool - - Args: - agent_id (_type_): _description_ - """ - self.logger.info(f"Removing agent {agent_id} from pool") - try: - for agent in self.agents: - if agent.id == agent_id: - self.agents.remove(agent) - return - raise ValueError(f"No agent found with ID {agent_id}") - except Exception as error: - print(f"Error removing agent from pool: {error}") - raise error - - async def async_remove_agent(self, agent_id): - """Remove agent from the agent pool - - Args: - agent_id (_type_): 
_description_ - """ - self.logger.info(f"Removing agent {agent_id} from pool") - try: - # Remove agent from pool asynchronously with asyncio - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.remove_agent, agent_id - ) - except Exception as error: - print(f"Error removing agent from pool: {error}") - raise error - - def scale_up(self, num_agents: int = 1): - """Scale up the agent pool - - Args: - num_agents (int, optional): _description_. Defaults to 1. - """ - self.logger.info(f"Scaling up agent pool by {num_agents}") - try: - for _ in range(num_agents): - self.agents.append(Agent()) - except Exception as error: - print(f"Error scaling up agent pool: {error}") - raise error - - def scale_down(self, num_agents: int = 1): - """Scale down the agent pool - - Args: - num_agents (int, optional): _description_. Defaults to 1. - """ - for _ in range(num_agents): - self.agents.pop() - - @tenacity.retry( - wait=tenacity.wait_fixed(1), - stop=tenacity.stop_after_attempt(3), - retry=tenacity.retry_if_exception_type(Exception), - ) - def run(self, *args, **kwargs): - """run the swarm network""" - app = self.get_app() - - try: - import uvicorn - - logger.info( - f"Running the swarm network with {len(self.agents)} on {self.host}:{self.port}" - ) - uvicorn.run( - app, - host=self.host, - port=self.port, - # workers=get_number_of_workers(), - *args, - **kwargs, - ) - - return app - except Exception as error: - logger.error(f"Error running the swarm network: {error}") - raise error - - -# # # Example usage -# if __name__ == "__main__": - -# agent1 = Agent( -# agent_name="Covid-19-Chat", -# agent_description="This agent provides information about COVID-19 symptoms.", -# llm=OpenAIChat(), -# max_loops="auto", -# autosave=True, -# verbose=True, -# stopping_condition="finish", -# ) - -# agents = [agent1] # Add more agents as needed -# swarm_name = "HealthSwarm" -# swarm_description = ( -# "A swarm of agents providing health-related information." 
-# ) - -# agent_api = SwarmNetwork(swarm_name, swarm_description, agents) -# agent_api.run() diff --git a/swarms/structs/tree_swarm.py b/swarms/structs/tree_swarm.py index 56b46642..75b0bf13 100644 --- a/swarms/structs/tree_swarm.py +++ b/swarms/structs/tree_swarm.py @@ -4,17 +4,14 @@ from datetime import datetime from typing import Any, List, Optional from pydantic import BaseModel, Field -from sentence_transformers import SentenceTransformer, util - from swarms.structs.agent import Agent from swarms.utils.loguru_logger import initialize_logger +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) -logger = initialize_logger(log_folder="tree_swarm") -# Pretrained model for embeddings -embedding_model = SentenceTransformer( - "all-MiniLM-L6-v2" -) # A small, fast model for embedding +logger = initialize_logger(log_folder="tree_swarm") # Pydantic Models for Logging @@ -68,7 +65,7 @@ class TreeAgent(Agent): name: str = None, description: str = None, system_prompt: str = None, - llm: callable = None, + model_name: str = "gpt-4o", agent_name: Optional[str] = None, *args, **kwargs, @@ -78,12 +75,29 @@ class TreeAgent(Agent): name=name, description=description, system_prompt=system_prompt, - llm=llm, + model_name=model_name, agent_name=agent_name, *args, **kwargs, ) - self.system_prompt_embedding = embedding_model.encode( + + try: + import sentence_transformers + except ImportError: + auto_check_and_download_package( + "sentence-transformers", package_manager="pip" + ) + import sentence_transformers + + self.sentence_transformers = sentence_transformers + + # Pretrained model for embeddings + self.embedding_model = ( + sentence_transformers.SentenceTransformer( + "all-MiniLM-L6-v2" + ) + ) + self.system_prompt_embedding = self.embedding_model.encode( system_prompt, convert_to_tensor=True ) @@ -103,7 +117,7 @@ class TreeAgent(Agent): Returns: float: Distance score between 0 and 1, with 0 being close and 1 being far. 
""" - similarity = util.pytorch_cos_sim( + similarity = self.sentence_transformers.util.pytorch_cos_sim( self.system_prompt_embedding, other_agent.system_prompt_embedding, ).item() @@ -154,12 +168,14 @@ class TreeAgent(Agent): # Perform embedding similarity match if keyword match is not found if not keyword_match: - task_embedding = embedding_model.encode( + task_embedding = self.embedding_model.encode( task, convert_to_tensor=True ) - similarity = util.pytorch_cos_sim( - self.system_prompt_embedding, task_embedding - ).item() + similarity = ( + self.sentence_transformers.util.pytorch_cos_sim( + self.system_prompt_embedding, task_embedding + ).item() + ) logger.info( f"Semantic similarity between task and {self.agent_name}: {similarity:.2f}" ) diff --git a/swarms/tools/__init__.py b/swarms/tools/__init__.py index ee68bd90..18ac51ac 100644 --- a/swarms/tools/__init__.py +++ b/swarms/tools/__init__.py @@ -28,6 +28,7 @@ from swarms.tools.cohere_func_call_schema import ( ParameterDefinition, ) from swarms.tools.tool_registry import ToolStorage, tool_registry +from swarms.tools.json_utils import base_model_to_json __all__ = [ @@ -51,4 +52,5 @@ __all__ = [ "ParameterDefinition", "ToolStorage", "tool_registry", + "base_model_to_json", ] diff --git a/swarms/tools/json_former.py b/swarms/tools/json_former.py index dcca9932..6e1358a9 100644 --- a/swarms/tools/json_former.py +++ b/swarms/tools/json_former.py @@ -1,7 +1,7 @@ import json from typing import Any, Dict, List, Union -from transformers import PreTrainedModel, PreTrainedTokenizer +from swarms.utils.lazy_loader import lazy_import_decorator from pydantic import BaseModel from swarms.tools.logits_processor import ( NumberStoppingCriteria, @@ -9,10 +9,23 @@ from swarms.tools.logits_processor import ( StringStoppingCriteria, ) from swarm_models.base_llm import BaseLLM +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) + +try: + import transformers +except ImportError: + 
auto_check_and_download_package( + "transformers", package_manager="pip" + ) + import transformers + GENERATION_MARKER = "|GENERATION|" +@lazy_import_decorator class Jsonformer: """ Initializes the FormatTools class. @@ -35,8 +48,8 @@ class Jsonformer: def __init__( self, - model: PreTrainedModel = None, - tokenizer: PreTrainedTokenizer = None, + model: transformers.PreTrainedModel = None, # type: ignore + tokenizer: transformers.PreTrainedTokenizer = None, # type: ignore json_schema: Union[Dict[str, Any], BaseModel] = None, schemas: List[Union[Dict[str, Any], BaseModel]] = [], prompt: str = None, diff --git a/swarms/tools/logits_processor.py b/swarms/tools/logits_processor.py index f67ff451..47978bc5 100644 --- a/swarms/tools/logits_processor.py +++ b/swarms/tools/logits_processor.py @@ -1,21 +1,35 @@ -import torch -from transformers import ( - LogitsWarper, - PreTrainedTokenizer, - StoppingCriteria, +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, ) -class StringStoppingCriteria(StoppingCriteria): +try: + import torch +except ImportError: + auto_check_and_download_package( + "torch", package_manager="pip", upgrade=True + ) + import torch + +try: + import transformers +except ImportError: + auto_check_and_download_package( + "transformers", package_manager="pip", upgrade=True + ) + import transformers + + +class StringStoppingCriteria(transformers.StoppingCriteria): def __init__( - self, tokenizer: PreTrainedTokenizer, prompt_length: int + self, tokenizer: transformers.PreTrainedTokenizer, prompt_length: int # type: ignore ): self.tokenizer = tokenizer self.prompt_length = prompt_length def __call__( self, - input_ids: torch.LongTensor, + input_ids: torch.LongTensor, # type: ignore _, ) -> bool: if len(input_ids[0]) <= self.prompt_length: @@ -31,10 +45,10 @@ class StringStoppingCriteria(StoppingCriteria): return result -class NumberStoppingCriteria(StoppingCriteria): +class 
NumberStoppingCriteria(transformers.StoppingCriteria): def __init__( self, - tokenizer: PreTrainedTokenizer, + tokenizer: transformers.PreTrainedTokenizer, # type: ignore prompt_length: int, precision: int = 3, ): @@ -44,8 +58,8 @@ class NumberStoppingCriteria(StoppingCriteria): def __call__( self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, + input_ids: torch.LongTensor, # type: ignore + scores: torch.FloatTensor, # type: ignore ) -> bool: decoded = self.tokenizer.decode( input_ids[0][self.prompt_length :], @@ -71,8 +85,8 @@ class NumberStoppingCriteria(StoppingCriteria): return False -class OutputNumbersTokens(LogitsWarper): - def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str): +class OutputNumbersTokens(transformers.LogitsWarper): + def __init__(self, tokenizer: transformers.PreTrainedTokenizer, prompt: str): # type: ignore self.tokenizer = tokenizer self.tokenized_prompt = tokenizer(prompt, return_tensors="pt") vocab_size = len(tokenizer) diff --git a/swarms/utils/auto_download_check_packages.py b/swarms/utils/auto_download_check_packages.py new file mode 100644 index 00000000..555967a3 --- /dev/null +++ b/swarms/utils/auto_download_check_packages.py @@ -0,0 +1,146 @@ +""" +Package installation utility that checks for package existence and installs if needed. +Supports both pip and conda package managers. +""" + +import importlib.util +import subprocess +import sys +from typing import Literal, Optional, Union +from swarms.utils.loguru_logger import initialize_logger +import pkg_resources + + +logger = initialize_logger("autocheckpackages") + + +def check_and_install_package( + package_name: str, + package_manager: Literal["pip", "conda"] = "pip", + version: Optional[str] = None, + upgrade: bool = False, +) -> bool: + """ + Check if a package is installed and install it if not found. 
+ + Args: + package_name: Name of the package to check/install + package_manager: Package manager to use ('pip' or 'conda') + version: Specific version to install (optional) + upgrade: Whether to upgrade the package if it exists + + Returns: + bool: True if package is available after check/install, False if installation failed + + Raises: + ValueError: If invalid package manager is specified + """ + try: + # Check if package exists + if package_manager == "pip": + try: + pkg_resources.get_distribution(package_name) + if not upgrade: + logger.info( + f"Package {package_name} is already installed" + ) + return True + except pkg_resources.DistributionNotFound: + pass + + # Construct installation command + cmd = [sys.executable, "-m", "pip", "install"] + if upgrade: + cmd.append("--upgrade") + + if version: + cmd.append(f"{package_name}=={version}") + else: + cmd.append(package_name) + + elif package_manager == "conda": + # Check if conda is available + try: + subprocess.run( + ["conda", "--version"], + check=True, + capture_output=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + logger.error( + "Conda is not available. Please install conda first." 
+ ) + return False + + # Construct conda command + cmd = ["conda", "install", "-y"] + if version: + cmd.append(f"{package_name}={version}") + else: + cmd.append(package_name) + else: + raise ValueError( + f"Invalid package manager: {package_manager}" + ) + + # Run installation + logger.info(f"Installing {package_name}...") + subprocess.run( + cmd, check=True, capture_output=True, text=True + ) + + # Verify installation + try: + importlib.import_module(package_name) + logger.info(f"Successfully installed {package_name}") + return True + except ImportError: + logger.error( + f"Package {package_name} was installed but cannot be imported" + ) + return False + + except subprocess.CalledProcessError as e: + logger.error(f"Failed to install {package_name}: {e.stderr}") + return False + except Exception as e: + logger.error( + f"Unexpected error while installing {package_name}: {str(e)}" + ) + return False + + +def auto_check_and_download_package( + packages: Union[str, list[str]], + package_manager: Literal["pip", "conda"] = "pip", + upgrade: bool = False, +) -> bool: + """ + Ensure multiple packages are installed. 
+ + Args: + packages: Single package name or list of package names + package_manager: Package manager to use ('pip' or 'conda') + upgrade: Whether to upgrade existing packages + + Returns: + bool: True if all packages are available, False if any installation failed + """ + if isinstance(packages, str): + packages = [packages] + + success = True + for package in packages: + if ":" in package: + name, version = package.split(":") + if not check_and_install_package( + name, package_manager, version, upgrade + ): + success = False + else: + if not check_and_install_package( + package, package_manager, upgrade=upgrade + ): + success = False + + return success diff --git a/swarms/utils/lazy_loader.py b/swarms/utils/lazy_loader.py new file mode 100644 index 00000000..c9725e51 --- /dev/null +++ b/swarms/utils/lazy_loader.py @@ -0,0 +1,263 @@ +""" +Lazy Package Loader + +This module provides utilities for lazy loading Python packages to improve startup time +and reduce memory usage by only importing packages when they are actually used. + +Features: +- Type-safe lazy loading of packages +- Support for nested module imports +- Auto-completion support in IDEs +- Thread-safe implementation +- Comprehensive test coverage +""" + +from types import ModuleType +from typing import ( + Optional, + Dict, + Any, + Callable, + Type, + TypeVar, + Union, + cast, +) +import importlib +import functools +import threading +from importlib.util import find_spec +from swarms.utils.auto_download_check_packages import ( + auto_check_and_download_package, +) + + +T = TypeVar("T") +C = TypeVar("C") + + +class ImportError(Exception): + """Raised when a lazy import fails.""" + + pass + + +class LazyLoader: + """ + A thread-safe lazy loader for Python packages that only imports them when accessed. 
+ + Attributes: + _module_name (str): The name of the module to be lazily loaded + _module (Optional[ModuleType]): The cached module instance once loaded + _lock (threading.Lock): Thread lock for safe concurrent access + + Examples: + >>> np = LazyLoader('numpy') + >>> # numpy is not imported yet + >>> result = np.array([1, 2, 3]) + >>> # numpy is imported only when first used + """ + + def __init__(self, module_name: str) -> None: + """ + Initialize the lazy loader with a module name. + + Args: + module_name: The fully qualified name of the module to lazily load + + Raises: + ImportError: If the module cannot be found in sys.path + """ + self._module_name = module_name + self._module: Optional[ModuleType] = None + self._lock = threading.Lock() + + auto_check_and_download_package( + module_name, package_manager="pip" + ) + + # Verify module exists without importing it + if find_spec(module_name) is None: + raise ImportError( + f"Module '{module_name}' not found in sys.path" + ) + + def _load_module(self) -> ModuleType: + """ + Thread-safe module loading. + + Returns: + ModuleType: The loaded module + + Raises: + ImportError: If module import fails + """ + if self._module is None: + with self._lock: + # Double-check pattern + if self._module is None: + try: + self._module = importlib.import_module( + self._module_name + ) + except Exception as e: + raise ImportError( + f"Failed to import '{self._module_name}': {str(e)}" + ) + return cast(ModuleType, self._module) + + def __getattr__(self, name: str) -> Any: + """ + Intercepts attribute access to load the module if needed. 
+ + Args: + name: The attribute name being accessed + + Returns: + Any: The requested attribute from the loaded module + + Raises: + AttributeError: If the attribute doesn't exist in the module + """ + module = self._load_module() + try: + return getattr(module, name) + except AttributeError: + raise AttributeError( + f"Module '{self._module_name}' has no attribute '{name}'" + ) + + def __dir__(self) -> list[str]: + """ + Returns list of attributes for autocomplete support. + + Returns: + List[str]: Available attributes in the module + """ + return dir(self._load_module()) + + def is_loaded(self) -> bool: + """ + Check if the module has been loaded. + + Returns: + bool: True if module is loaded, False otherwise + """ + return self._module is not None + + +class LazyLoaderMetaclass(type): + """Metaclass to handle lazy loading behavior""" + + def __call__(cls, *args, **kwargs): + if hasattr(cls, "_lazy_loader"): + return super().__call__(*args, **kwargs) + return super().__call__(*args, **kwargs) + + +class LazyClassLoader: + """ + A descriptor that creates the actual class only when accessed, + with proper inheritance support. 
+ """ + + def __init__( + self, class_name: str, bases: tuple, namespace: Dict[str, Any] + ): + self.class_name = class_name + self.bases = bases + self.namespace = namespace + self._real_class: Optional[Type] = None + self._lock = threading.Lock() + + def _create_class(self) -> Type: + """Creates the actual class if it hasn't been created yet.""" + if self._real_class is None: + with self._lock: + if self._real_class is None: + # Update namespace to include metaclass + namespace = dict(self.namespace) + namespace["__metaclass__"] = LazyLoaderMetaclass + + # Create the class with metaclass + new_class = LazyLoaderMetaclass( + self.class_name, self.bases, namespace + ) + + # Store reference to this loader + new_class._lazy_loader = self + self._real_class = new_class + + return cast(Type, self._real_class) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Creates an instance of the lazy loaded class.""" + real_class = self._create_class() + # Use the metaclass __call__ method + return real_class(*args, **kwargs) + + def __instancecheck__(self, instance: Any) -> bool: + """Support for isinstance() checks""" + real_class = self._create_class() + return isinstance(instance, real_class) + + def __subclasscheck__(self, subclass: Type) -> bool: + """Support for issubclass() checks""" + real_class = self._create_class() + return issubclass(subclass, real_class) + + +def lazy_import(*names: str) -> Dict[str, LazyLoader]: + """ + Create multiple lazy loaders at once. 
+ + Args: + *names: Module names to create lazy loaders for + + Returns: + Dict[str, LazyLoader]: Dictionary mapping module names to their lazy loaders + + Examples: + >>> modules = lazy_import('numpy', 'pandas', 'matplotlib.pyplot') + >>> np = modules['numpy'] + >>> pd = modules['pandas'] + >>> plt = modules['matplotlib.pyplot'] + """ + return {name.split(".")[-1]: LazyLoader(name) for name in names} + + +def lazy_import_decorator( + target: Union[Callable[..., T], Type[C]] +) -> Union[Callable[..., T], Type[C], LazyClassLoader]: + """ + Enhanced decorator that supports both lazy imports and lazy class loading. + """ + if isinstance(target, type): + # Store the original class details + namespace = { + name: value + for name, value in target.__dict__.items() + if not name.startswith("__") + or name in ("__init__", "__new__") + } + + # Create lazy loader + loader = LazyClassLoader( + target.__name__, target.__bases__, namespace + ) + + # Preserve class metadata + loader.__module__ = target.__module__ + loader.__doc__ = target.__doc__ + + # Add reference to original class + loader._original_class = target + + return loader + else: + # Handle function decoration + @functools.wraps(target) + def wrapper(*args: Any, **kwargs: Any) -> T: + return target(*args, **kwargs) + + return wrapper diff --git a/swarms/utils/litellm.py b/swarms/utils/litellm.py index 5bdd208d..8267e6be 100644 --- a/swarms/utils/litellm.py +++ b/swarms/utils/litellm.py @@ -8,6 +8,7 @@ except ImportError: from litellm import completion litellm.set_verbose = True + litellm.ssl_verify = False class LiteLLM: @@ -23,6 +24,7 @@ class LiteLLM: stream: bool = False, temperature: float = 0.5, max_tokens: int = 4000, + ssl_verify: bool = False, ): """ Initialize the LiteLLM with the given parameters. 
@@ -39,6 +41,7 @@ class LiteLLM: self.stream = stream self.temperature = temperature self.max_tokens = max_tokens + self.ssl_verify = ssl_verify def _prepare_messages(self, task: str) -> list: """ diff --git a/swarms/utils/openai_tts.py b/swarms/utils/openai_tts.py deleted file mode 100644 index 3cfcbd05..00000000 --- a/swarms/utils/openai_tts.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -from loguru import logger -import pygame -import requests -import tempfile -from openai import OpenAI - - -class OpenAITTS: - """ - A class to interact with OpenAI API and play the generated audio with improved streaming capabilities. - """ - - def __init__(self, *args, **kwargs): - self.client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs - ) - pygame.init() - - def run( - self, task: str, play_sound: bool = True, *args, **kwargs - ): - """ - Run a task with the OpenAI API and optionally play the generated audio with improved streaming. - - Args: - task (str): The task to be executed. - play_sound (bool): If True, play the generated audio. - - Returns: - None - """ - try: - response = self.client.audio.speech.create( - model="tts-1", - voice="nova", - input=task, - *args, - **kwargs, - ) - audio_url = response["url"] - logger.info("Task completed successfully.") - - if play_sound: - with tempfile.NamedTemporaryFile( - delete=False, suffix=".mp3" - ) as tmp_file: - with requests.get(audio_url, stream=True) as r: - r.raise_for_status() - for chunk in r.iter_content(chunk_size=8192): - tmp_file.write(chunk) - pygame.mixer.music.load(tmp_file.name) - pygame.mixer.music.play() - while pygame.mixer.music.get_busy(): - pygame.time.Clock().tick(10) - except Exception as e: - logger.error(f"Error during task execution: {str(e)}") - - -# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY")) -# client.run("Hello world! 
This is a streaming test.", play_sound=True) - - -def text_to_speech( - task: str, play_sound: bool = True, *args, **kwargs -): - out = OpenAITTS().run( - task, play_sound=play_sound, *args, **kwargs - ) - return out - - -# print(text_to_speech(task="hello")) diff --git a/tree_swarm_test.py b/tree_swarm_test.py new file mode 100644 index 00000000..cb0d41c7 --- /dev/null +++ b/tree_swarm_test.py @@ -0,0 +1,42 @@ +from swarms.structs.tree_swarm import ForestSwarm, Tree, TreeAgent + + +agents_tree1 = [ + TreeAgent( + system_prompt="Stock Analysis Agent", + agent_name="Stock Analysis Agent", + ), + TreeAgent( + system_prompt="Financial Planning Agent", + agent_name="Financial Planning Agent", + ), + TreeAgent( + agent_name="Retirement Strategy Agent", + system_prompt="Retirement Strategy Agent", + ), +] + +agents_tree2 = [ + TreeAgent( + system_prompt="Tax Filing Agent", + agent_name="Tax Filing Agent", + ), + TreeAgent( + system_prompt="Investment Strategy Agent", + agent_name="Investment Strategy Agent", + ), + TreeAgent( + system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent" + ), +] + +# Create trees +tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1) +tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2) + +# Create the ForestSwarm +multi_agent_structure = ForestSwarm(trees=[tree1, tree2]) + +# Run a task +task = "Our company is incorporated in delaware, how do we do our taxes for free?" +multi_agent_structure.run(task) diff --git a/zpk.py b/zpk.py new file mode 100644 index 00000000..af37e01f --- /dev/null +++ b/zpk.py @@ -0,0 +1,206 @@ +from swarms import Agent +from loguru import logger +import random +import re + +# Configure loguru +logger.add("zkp_log.log", rotation="500 KB", retention="10 days", level="INFO") + + +class ProverAgent: + """ + Prover Agent for Zero Knowledge Proof. + + Responsibilities: + - Generate commitments based on a secret. + - Respond to challenges from the Verifier. 
+ + Attributes: + agent (Agent): Swarms agent instance. + p (int): The prime modulus. + g (int): The generator. + x (int): The Prover's secret. + """ + + def __init__(self, p: int, g: int, secret: int): + self.p = p + self.g = g + self.x = secret # Prover's secret + self.agent = Agent( + agent_name="ProverAgent", + model_name="gpt-4o-mini", + max_loop=1, + interactive=False, + streaming_on=True, + system_prompt=( + "You are the Prover in a Zero Knowledge Proof (ZKP) system. " + "Your responsibilities are to generate commitments based on a secret value and " + "respond to challenges from the Verifier without revealing the secret. " + "Follow mathematical rules of modular arithmetic when performing computations." + ), + ) + logger.info("Initialized ProverAgent with p={}, g={}, secret={}", p, g, secret) + + def generate_commitment(self) -> tuple[int, int]: + """ + Generates a random commitment for the proof. + + Returns: + tuple[int, int]: The random value (r) and the commitment (t). + """ + r = random.randint(1, self.p - 2) + task = ( + f"Compute the commitment t = g^r % p for g={self.g}, r={r}, p={self.p}. " + "Return only the numerical value of t as an integer." + ) + t = self.agent.run(task=task) + t_value = self._extract_integer(t, "commitment") + logger.info("Prover generated commitment: r={}, t={}", r, t_value) + return r, t_value + + def _extract_integer(self, response: str, label: str) -> int: + """ + Extracts an integer from the LLM response. + + Args: + response (str): The response from the agent. + label (str): A label for logging purposes. + + Returns: + int: The extracted integer value. 
+ """ + try: + # Use regex to find the first integer in the response + match = re.search(r"\b\d+\b", response) + if match: + value = int(match.group(0)) + return value + else: + raise ValueError(f"No integer found in {label} response: {response}") + except Exception as e: + logger.error("Failed to extract integer from {label} response: {response}") + raise ValueError(f"Invalid {label} response: {response}") from e + + def respond_to_challenge(self, r: int, c: int) -> int: + """ + Computes the response to a challenge. + + Args: + r (int): The random value used in the commitment. + c (int): The challenge issued by the Verifier. + + Returns: + int: The response (z). + """ + task = f"Compute the response z = (r + c * x) % (p-1) for r={r}, c={c}, x={self.x}, p={self.p}." + z = self.agent.run(task=task) + logger.info("Prover responded to challenge: z={}", z) + return int(z) + + +class VerifierAgent: + """ + Verifier Agent for Zero Knowledge Proof. + + Responsibilities: + - Issue challenges to the Prover. + - Verify the Prover's response. + + Attributes: + agent (Agent): Swarms agent instance. + p (int): The prime modulus. + g (int): The generator. + y (int): The public value from the Prover. + """ + + def __init__(self, p: int, g: int, y: int): + self.p = p + self.g = g + self.y = y # Public value + self.agent = Agent( + agent_name="VerifierAgent", + model_name="gpt-4o-mini", + max_loop=1, + interactive=False, + streaming_on=True, + system_prompt=( + "You are the Verifier in a Zero Knowledge Proof (ZKP) system. " + "Your responsibilities are to issue random challenges and verify the Prover's response. " + "Use modular arithmetic to check if the proof satisfies g^z % p == (t * y^c) % p." + ), + ) + logger.info("Initialized VerifierAgent with p={}, g={}, y={}", p, g, y) + + def issue_challenge(self) -> int: + """ + Issues a random challenge to the Prover. + + Returns: + int: The challenge value (c). 
+ """ + c = random.randint(1, 10) + logger.info("Verifier issued challenge: c={}", c) + return c + + def verify_proof(self, t: int, z: int, c: int) -> bool: + """ + Verifies the Prover's response. + + Args: + t (int): The commitment from the Prover. + z (int): The response from the Prover. + c (int): The challenge issued to the Prover. + + Returns: + bool: True if the proof is valid, False otherwise. + """ + task = f"Verify if g^z % p == (t * y^c) % p for g={self.g}, z={z}, p={self.p}, t={t}, y={self.y}, c={c}." + verification_result = self.agent.run(task=task) + is_valid = verification_result.strip().lower() == "true" + logger.info("Verifier checked proof: t={}, z={}, c={}, valid={}", t, z, c, is_valid) + return is_valid + + +class CoordinatorAgent: + """ + Coordinator for orchestrating the Zero Knowledge Proof protocol. + + Responsibilities: + - Initialize parameters. + - Facilitate interaction between Prover and Verifier agents. + """ + + def __init__(self, p: int, g: int, secret: int): + self.p = p + self.g = g + self.prover = ProverAgent(p, g, secret) + y = pow(g, secret, p) # Public value + self.verifier = VerifierAgent(p, g, y) + logger.info("Coordinator initialized with p={}, g={}, secret={}", p, g, secret) + + def orchestrate(self) -> bool: + """ + Orchestrates the Zero Knowledge Proof protocol. + + Returns: + bool: True if the proof is valid, False otherwise. + """ + logger.info("Starting ZKP protocol orchestration.") + r, t = self.prover.generate_commitment() + c = self.verifier.issue_challenge() + z = self.prover.respond_to_challenge(r, c) + is_valid = self.verifier.verify_proof(t, z, c) + logger.info("ZKP protocol completed. 
Valid proof: {}", is_valid) + return is_valid + + +if __name__ == "__main__": + # Example parameters + p = 23 # Prime number + g = 5 # Generator + secret = 7 # Prover's secret + + # Initialize the Coordinator and run the protocol + coordinator = CoordinatorAgent(p, g, secret) + result = coordinator.orchestrate() + print(f"Zero Knowledge Proof Verification Result: {'Valid' if result else 'Invalid'}")