commit a444440117
@@ -0,0 +1,43 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
#
# This workflow file requires a free account on Bearer.com to manage findings, notifications and more.
# See https://docs.bearer.com/guides/bearer-cloud/
name: Bearer

on:
  push:
    branches: ["master"]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: ["master"]
  schedule:
    - cron: '24 22 * * 6'

permissions:
  contents: read # for actions/checkout to fetch code
  security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
  actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status

jobs:
  bearer:
    runs-on: ubuntu-latest
    steps:
      # Checkout project source
      - uses: actions/checkout@v4
      # Scan code using Bearer CLI
      - name: Run Report
        id: report
        uses: bearer/bearer-action@828eeb928ce2f4a7ca5ed57fb8b59508cb8c79bc
        with:
          api-key: ${{ secrets.BEARER_TOKEN }}
          format: sarif
          output: results.sarif
          exit-code: 0
      # Upload SARIF file generated in previous step
      - name: Upload SARIF file
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: results.sarif
@@ -0,0 +1,39 @@
# Dependency Review Action
#
# This Action will scan dependency manifest files that change as part of a Pull Request,
# surfacing known-vulnerable versions of the packages declared or updated in the PR.
# Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable
# packages will be blocked from merging.
#
# Source repository: https://github.com/actions/dependency-review-action
# Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement
name: 'Dependency review'
on:
  pull_request:
    branches: [ "master" ]

# If using a dependency submission action in this workflow, this permission will need to be set to:
#
# permissions:
#   contents: write
#
# https://docs.github.com/en/enterprise-cloud@latest/code-security/supply-chain-security/understanding-your-software-supply-chain/using-the-dependency-submission-api
permissions:
  contents: read
  # Write permissions for pull-requests are required for using the `comment-summary-in-pr` option, comment out if you aren't using this option
  pull-requests: write

jobs:
  dependency-review:
    runs-on: ubuntu-latest
    steps:
      - name: 'Checkout repository'
        uses: actions/checkout@v4
      - name: 'Dependency Review'
        uses: actions/dependency-review-action@v4
        # Commonly enabled options, see https://github.com/actions/dependency-review-action#configuration-options for all available options.
        with:
          comment-summary-in-pr: always
          # fail-on-severity: moderate
          # deny-licenses: GPL-1.0-or-later, LGPL-2.0-or-later
          # retry-on-snapshot-warnings: true
@@ -0,0 +1,18 @@
name: Docker Image CI

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:

  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Build the Docker image
      run: docker build . --file Dockerfile --tag my-image-name:$(date +%s)
@@ -0,0 +1,50 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# This workflow integrates Python Static Analyzer (Pysa) with
# GitHub's Code Scanning feature.
#
# Python Static Analyzer (Pysa) is a security-focused static
# analysis tool that tracks flows of data from where they
# originate to where they terminate in a dangerous location.
#
# See https://pyre-check.org/docs/pysa-basics/

name: Pysa

on:
  workflow_dispatch:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]
  schedule:
    - cron: '43 5 * * 3'

permissions:
  contents: read

jobs:
  pysa:
    permissions:
      actions: read
      contents: read
      security-events: write

    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true

      - name: Run Pysa
        uses: facebook/pysa-action@f46a63777e59268613bd6e2ff4e29f144ca9e88b
        with:
          # To customize these inputs:
          # See https://github.com/facebook/pysa-action#inputs
          repo-directory: './'
          requirements-path: 'requirements.txt'
          infer-types: true
          include-default-sapp-filters: true
@@ -0,0 +1,34 @@
name: Python Package using Conda

on: [push]

jobs:
  build-linux:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v3
      with:
        python-version: '3.10'
    - name: Add conda to system path
      run: |
        # $CONDA is an environment variable pointing to the root of the miniconda directory
        echo $CONDA/bin >> $GITHUB_PATH
    - name: Install dependencies
      run: |
        conda env update --file environment.yml --name base
    - name: Lint with flake8
      run: |
        conda install flake8
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        conda install pytest
        pytest
@@ -0,0 +1,49 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# This workflow file requires a free account on Semgrep.dev to
# manage rules, file ignores, notifications, and more.
#
# See https://semgrep.dev/docs

name: Semgrep

on:
  push:
    branches: [ "master" ]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: [ "master" ]
  schedule:
    - cron: '19 7 * * 3'

permissions:
  contents: read

jobs:
  semgrep:
    permissions:
      contents: read # for actions/checkout to fetch code
      security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
      actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status
    name: Scan
    runs-on: ubuntu-latest
    steps:
      # Checkout project source
      - uses: actions/checkout@v4

      # Scan code using project's configuration on https://semgrep.dev/manage
      - uses: returntocorp/semgrep-action@fcd5ab7459e8d91cb1777481980d1b18b4fc6735
        with:
          publishToken: ${{ secrets.SEMGREP_APP_TOKEN }}
          publishDeployment: ${{ secrets.SEMGREP_DEPLOYMENT_ID }}
          generateSarif: "1"

      # Upload SARIF file generated in previous step
      - name: Upload SARIF file
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: semgrep.sarif
        if: always()
@@ -0,0 +1,48 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: trivy

on:
  push:
    branches: [ "master" ]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: [ "master" ]
  schedule:
    - cron: '31 0 * * 5'

permissions:
  contents: read

jobs:
  build:
    permissions:
      contents: read # for actions/checkout to fetch code
      security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
      actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status
    name: Build
    runs-on: "ubuntu-20.04"
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Build an image from Dockerfile
        run: |
          docker build -t docker.io/my-organization/my-app:${{ github.sha }} .

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@7b7aa264d83dc58691451798b4d117d53d21edfe
        with:
          image-ref: 'docker.io/my-organization/my-app:${{ github.sha }}'
          format: 'template'
          template: '@/contrib/sarif.tpl'
          output: 'trivy-results.sarif'
          severity: 'CRITICAL,HIGH'

      - name: Upload Trivy scan results to GitHub Security tab
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: 'trivy-results.sarif'
@@ -0,0 +1,629 @@
import os
from fastapi import (
    FastAPI,
    HTTPException,
    status,
    Query,
    BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List
from loguru import logger
import uvicorn
from datetime import datetime, timedelta
from uuid import UUID, uuid4
from enum import Enum
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import traceback

from swarms import Agent
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure Loguru
logger.add(
    "logs/api_{time}.log",
    rotation="500 MB",
    retention="10 days",
    level="INFO",
    format="{time} {level} {message}",
    backtrace=True,
    diagnose=True,
)


class AgentStatus(str, Enum):
    """Enum for agent status."""

    IDLE = "idle"
    PROCESSING = "processing"
    ERROR = "error"
    MAINTENANCE = "maintenance"


class AgentConfig(BaseModel):
    """Configuration model for creating a new agent."""

    agent_name: str = Field(..., description="Name of the agent")
    model_name: str = Field(
        default="gpt-4",
        description="Name of the LLM to use, as provided by litellm",
    )
    description: str = Field(default="", description="Description of the agent's purpose")
    system_prompt: str = Field(..., description="System prompt for the agent")
    temperature: float = Field(
        default=0.1, ge=0.0, le=2.0, description="Temperature for the model"
    )
    max_loops: int = Field(default=1, ge=1, description="Maximum number of loops")
    autosave: bool = Field(default=True, description="Enable autosave")
    dashboard: bool = Field(default=False, description="Enable dashboard")
    verbose: bool = Field(default=True, description="Enable verbose output")
    dynamic_temperature_enabled: bool = Field(
        default=True, description="Enable dynamic temperature"
    )
    user_name: str = Field(default="default_user", description="Username for the agent")
    retry_attempts: int = Field(default=1, ge=1, description="Number of retry attempts")
    context_length: int = Field(default=200000, ge=1000, description="Context length")
    output_type: str = Field(default="string", description="Output type (string or json)")
    streaming_on: bool = Field(default=False, description="Enable streaming")
    tags: List[str] = Field(
        default_factory=list, description="Tags for categorizing the agent"
    )


class AgentUpdate(BaseModel):
    """Model for updating agent configuration."""

    description: Optional[str] = None
    system_prompt: Optional[str] = None
    temperature: Optional[float] = None
    max_loops: Optional[int] = None
    tags: Optional[List[str]] = None
    status: Optional[AgentStatus] = None


class AgentSummary(BaseModel):
    """Summary model for agent listing."""

    agent_id: UUID
    agent_name: str
    description: str
    created_at: datetime
    last_used: datetime
    total_completions: int
    tags: List[str]
    status: AgentStatus


class AgentMetrics(BaseModel):
    """Model for agent performance metrics."""

    total_completions: int
    average_response_time: float
    error_rate: float
    last_24h_completions: int
    total_tokens_used: int
    uptime_percentage: float
    success_rate: float
    peak_tokens_per_minute: int


class CompletionRequest(BaseModel):
    """Model for completion requests."""

    prompt: str = Field(..., description="The prompt to process")
    agent_id: UUID = Field(..., description="ID of the agent to use")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    temperature_override: Optional[float] = None
    stream: bool = Field(default=False, description="Enable streaming response")


class CompletionResponse(BaseModel):
    """Model for completion responses."""

    agent_id: UUID
    response: str
    metadata: Dict[str, Any]
    timestamp: datetime
    processing_time: float
    token_usage: Dict[str, int]


class AgentStore:
    """Enhanced store for managing agents."""

    def __init__(self):
        self.agents: Dict[UUID, Agent] = {}
        self.agent_metadata: Dict[UUID, Dict[str, Any]] = {}
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._ensure_directories()

    def _ensure_directories(self):
        """Ensure required directories exist."""
        Path("logs").mkdir(exist_ok=True)
        Path("states").mkdir(exist_ok=True)

    async def create_agent(self, config: AgentConfig) -> UUID:
        """Create a new agent with the given configuration."""
        try:
            agent = Agent(
                agent_name=config.agent_name,
                system_prompt=config.system_prompt,
                model_name=config.model_name,
                max_loops=config.max_loops,
                autosave=config.autosave,
                dashboard=config.dashboard,
                verbose=config.verbose,
                dynamic_temperature_enabled=config.dynamic_temperature_enabled,
                saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                user_name=config.user_name,
                retry_attempts=config.retry_attempts,
                context_length=config.context_length,
                return_step_meta=True,
                output_type="str",
                streaming_on=config.streaming_on,
            )

            agent_id = uuid4()
            self.agents[agent_id] = agent
            self.agent_metadata[agent_id] = {
                "description": config.description,
                "created_at": datetime.utcnow(),
                "last_used": datetime.utcnow(),
                "total_completions": 0,
                "tags": config.tags,
                "total_tokens": 0,
                "error_count": 0,
                "response_times": [],
                "status": AgentStatus.IDLE,
                "start_time": datetime.utcnow(),
                "downtime": timedelta(),
                "successful_completions": 0,
            }

            logger.info(f"Created agent with ID: {agent_id}")
            return agent_id

        except Exception as e:
            logger.error(f"Error creating agent: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to create agent: {str(e)}",
            )

    async def get_agent(self, agent_id: UUID) -> Agent:
        """Retrieve an agent by ID."""
        agent = self.agents.get(agent_id)
        if not agent:
            logger.error(f"Agent not found: {agent_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Agent {agent_id} not found",
            )
        return agent

    async def update_agent(self, agent_id: UUID, update: AgentUpdate) -> None:
        """Update agent configuration."""
        agent = await self.get_agent(agent_id)
        metadata = self.agent_metadata[agent_id]

        if update.system_prompt:
            agent.system_prompt = update.system_prompt
        if update.temperature is not None:
            agent.llm.temperature = update.temperature
        if update.max_loops is not None:
            agent.max_loops = update.max_loops
        if update.tags is not None:
            metadata["tags"] = update.tags
        if update.description is not None:
            metadata["description"] = update.description
        if update.status is not None:
            metadata["status"] = update.status
            if update.status == AgentStatus.MAINTENANCE:
                metadata["downtime"] += datetime.utcnow() - metadata["last_used"]

        logger.info(f"Updated agent {agent_id}")

    async def list_agents(
        self,
        tags: Optional[List[str]] = None,
        status: Optional[AgentStatus] = None,
    ) -> List[AgentSummary]:
        """List all agents, optionally filtered by tags and status."""
        summaries = []
        for agent_id, agent in self.agents.items():
            metadata = self.agent_metadata[agent_id]

            # Apply filters
            if tags and not any(tag in metadata["tags"] for tag in tags):
                continue
            if status and metadata["status"] != status:
                continue

            summaries.append(
                AgentSummary(
                    agent_id=agent_id,
                    agent_name=agent.agent_name,
                    description=metadata["description"],
                    created_at=metadata["created_at"],
                    last_used=metadata["last_used"],
                    total_completions=metadata["total_completions"],
                    tags=metadata["tags"],
                    status=metadata["status"],
                )
            )
        return summaries

    async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics:
        """Get performance metrics for an agent."""
        metadata = self.agent_metadata[agent_id]
        response_times = metadata["response_times"]

        # Calculate metrics
        total_time = datetime.utcnow() - metadata["start_time"]
        uptime = total_time - metadata["downtime"]
        uptime_percentage = (
            uptime.total_seconds() / total_time.total_seconds()
        ) * 100

        success_rate = (
            metadata["successful_completions"]
            / metadata["total_completions"]
            * 100
            if metadata["total_completions"] > 0
            else 0
        )

        return AgentMetrics(
            total_completions=metadata["total_completions"],
            average_response_time=(
                sum(response_times) / len(response_times)
                if response_times
                else 0
            ),
            error_rate=(
                metadata["error_count"] / metadata["total_completions"]
                if metadata["total_completions"] > 0
                else 0
            ),
            last_24h_completions=sum(
                1
                for t in response_times
                if (datetime.utcnow() - t).days < 1
            ),
            total_tokens_used=metadata["total_tokens"],
            uptime_percentage=uptime_percentage,
            success_rate=success_rate,
            # "tokens_per_minute" is stored as a dict keyed by minute, so take
            # the peak over its values (0 if nothing has been recorded yet).
            peak_tokens_per_minute=max(
                metadata.get("tokens_per_minute", {}).values(), default=0
            ),
        )

    async def clone_agent(self, agent_id: UUID, new_name: str) -> UUID:
        """Clone an existing agent with a new name."""
        original_agent = await self.get_agent(agent_id)
        original_metadata = self.agent_metadata[agent_id]

        config = AgentConfig(
            agent_name=new_name,
            description=f"Clone of {original_agent.agent_name}",
            system_prompt=original_agent.system_prompt,
            model_name=original_agent.llm.model_name,
            temperature=original_agent.llm.temperature,
            max_loops=original_agent.max_loops,
            tags=original_metadata["tags"],
        )

        return await self.create_agent(config)

    async def delete_agent(self, agent_id: UUID) -> None:
        """Delete an agent."""
        if agent_id not in self.agents:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Agent {agent_id} not found",
            )

        # Clean up any resources
        agent = self.agents[agent_id]
        if agent.autosave and os.path.exists(agent.saved_state_path):
            os.remove(agent.saved_state_path)

        del self.agents[agent_id]
        del self.agent_metadata[agent_id]
        logger.info(f"Deleted agent {agent_id}")

    async def process_completion(
        self,
        agent: Agent,
        prompt: str,
        agent_id: UUID,
        max_tokens: Optional[int] = None,
        temperature_override: Optional[float] = None,
    ) -> CompletionResponse:
        """Process a completion request using the specified agent."""
        start_time = datetime.utcnow()
        metadata = self.agent_metadata[agent_id]

        try:
            # Update agent status
            metadata["status"] = AgentStatus.PROCESSING
            metadata["last_used"] = start_time

            # Apply temporary overrides if specified
            original_temp = agent.llm.temperature
            if temperature_override is not None:
                agent.llm.temperature = temperature_override

            # Process the completion
            response = agent.run(prompt)

            # Reset overrides
            if temperature_override is not None:
                agent.llm.temperature = original_temp

            # Update metrics
            processing_time = (datetime.utcnow() - start_time).total_seconds()
            metadata["response_times"].append(processing_time)
            metadata["total_completions"] += 1
            metadata["successful_completions"] += 1

            # Estimate token usage (this is a rough estimate)
            prompt_tokens = len(prompt.split()) * 1.3
            completion_tokens = len(response.split()) * 1.3
            total_tokens = int(prompt_tokens + completion_tokens)
            metadata["total_tokens"] += total_tokens
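            # For scale: a 100-word prompt comes out to int(100 * 1.3) = 130
            # estimated prompt tokens under this heuristic; actual
            # provider-reported usage may differ noticeably.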

            # Update tokens per minute tracking
            current_minute = datetime.utcnow().replace(second=0, microsecond=0)
            if "tokens_per_minute" not in metadata:
                metadata["tokens_per_minute"] = {}
            metadata["tokens_per_minute"][current_minute] = (
                metadata["tokens_per_minute"].get(current_minute, 0) + total_tokens
            )

            return CompletionResponse(
                agent_id=agent_id,
                response=response,
                metadata={
                    "agent_name": agent.agent_name,
                    "model_name": agent.llm.model_name,
                    "temperature": agent.llm.temperature,
                },
                timestamp=datetime.utcnow(),
                processing_time=processing_time,
                token_usage={
                    "prompt_tokens": int(prompt_tokens),
                    "completion_tokens": int(completion_tokens),
                    "total_tokens": total_tokens,
                },
            )

        except Exception as e:
            metadata["error_count"] += 1
            metadata["status"] = AgentStatus.ERROR
            logger.error(
                f"Error in completion processing: {str(e)}\n{traceback.format_exc()}"
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error processing completion: {str(e)}",
            )
        finally:
            metadata["status"] = AgentStatus.IDLE


class SwarmsAPI:
    """Enhanced API class for Swarms agent integration."""

    def __init__(self):
        self.app = FastAPI(
            title="Swarms Agent API",
            description="Production-grade API for Swarms agent interaction",
            version="1.0.0",
            docs_url="/v1/docs",
            redoc_url="/v1/redoc",
        )
        self.store = AgentStore()
        # Configure CORS
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # Configure appropriately for production
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._setup_routes()

    def _setup_routes(self):
        """Set up API routes."""

        @self.app.post("/v1/agent", response_model=Dict[str, UUID])
        async def create_agent(config: AgentConfig):
            """Create a new agent with the specified configuration."""
            agent_id = await self.store.create_agent(config)
            return {"agent_id": agent_id}

        @self.app.get("/v1/agents", response_model=List[AgentSummary])
        async def list_agents(
            tags: Optional[List[str]] = Query(None),
            status: Optional[AgentStatus] = None,
        ):
            """List all agents, optionally filtered by tags and status."""
            return await self.store.list_agents(tags, status)

        @self.app.patch("/v1/agent/{agent_id}", response_model=Dict[str, str])
        async def update_agent(agent_id: UUID, update: AgentUpdate):
            """Update an existing agent's configuration."""
            await self.store.update_agent(agent_id, update)
            return {"status": "updated"}

        @self.app.get("/v1/agent/{agent_id}/metrics", response_model=AgentMetrics)
        async def get_agent_metrics(agent_id: UUID):
            """Get performance metrics for a specific agent."""
            return await self.store.get_agent_metrics(agent_id)

        @self.app.post("/v1/agent/{agent_id}/clone", response_model=Dict[str, UUID])
        async def clone_agent(agent_id: UUID, new_name: str):
            """Clone an existing agent with a new name."""
            new_id = await self.store.clone_agent(agent_id, new_name)
            return {"agent_id": new_id}

        @self.app.delete("/v1/agent/{agent_id}")
        async def delete_agent(agent_id: UUID):
            """Delete an agent."""
            await self.store.delete_agent(agent_id)
            return {"status": "deleted"}

        @self.app.post("/v1/agent/completions", response_model=CompletionResponse)
        async def create_completion(
            request: CompletionRequest,
            background_tasks: BackgroundTasks,
        ):
            """Process a completion request with the specified agent."""
            try:
                agent = await self.store.get_agent(request.agent_id)

                # Process completion
                response = await self.store.process_completion(
                    agent,
                    request.prompt,
                    request.agent_id,
                    request.max_tokens,
                    request.temperature_override,
                )

                # Schedule background cleanup
                background_tasks.add_task(
                    self._cleanup_old_metrics, request.agent_id
                )

                return response

            except Exception as e:
                logger.error(f"Error processing completion: {str(e)}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Error processing completion: {str(e)}",
                )

        @self.app.get("/v1/agent/{agent_id}/status")
        async def get_agent_status(agent_id: UUID):
            """Get the current status of an agent."""
            metadata = self.store.agent_metadata.get(agent_id)
            if not metadata:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Agent {agent_id} not found",
                )
            return {
                "agent_id": agent_id,
                "status": metadata["status"],
                "last_used": metadata["last_used"],
                "total_completions": metadata["total_completions"],
                "error_count": metadata["error_count"],
            }

    async def _cleanup_old_metrics(self, agent_id: UUID):
        """Clean up old metrics data to prevent memory bloat."""
        metadata = self.store.agent_metadata.get(agent_id)
        if metadata:
            # Keep only last 24 hours of response times
            cutoff = datetime.utcnow() - timedelta(days=1)
            metadata["response_times"] = [
                t
                for t in metadata["response_times"]
                if isinstance(t, (int, float)) and t > cutoff.timestamp()
            ]

            # Clean up old tokens per minute data
            if "tokens_per_minute" in metadata:
                metadata["tokens_per_minute"] = {
                    k: v
                    for k, v in metadata["tokens_per_minute"].items()
                    if k > cutoff
                }


def create_app() -> FastAPI:
    """Create and configure the FastAPI application."""
    api = SwarmsAPI()
    return api.app


if __name__ == "__main__":
    # Configure uvicorn logging
    logger.info("API Starting")
    uvicorn.run(
        "main:create_app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        workers=4,
        factory=True,  # "main:create_app" is an app factory, not an app instance
    )
@@ -0,0 +1,107 @@
import requests
from loguru import logger
import time

# Configure loguru
logger.add(
    "api_tests_{time}.log",
    rotation="100 MB",
    level="DEBUG",
    format="{time} {level} {message}",
)

BASE_URL = "http://localhost:8000/v1"


def test_create_agent():
    """Test creating a new agent."""
    logger.info("Testing agent creation")

    payload = {
        "agent_name": "Test Agent",
        "system_prompt": "You are a helpful assistant",
        "model_name": "gpt-4",
        "description": "Test agent",
        "tags": ["test"],
    }

    response = requests.post(f"{BASE_URL}/agent", json=payload)
    logger.debug(f"Create response: {response.json()}")

    if response.status_code == 200:
        logger.success("Successfully created agent")
        return response.json()["agent_id"]
    else:
        logger.error(f"Failed to create agent: {response.text}")
        return None


def test_list_agents():
    """Test listing all agents."""
    logger.info("Testing agent listing")

    response = requests.get(f"{BASE_URL}/agents")
    logger.debug(f"List response: {response.json()}")

    if response.status_code == 200:
        logger.success(f"Found {len(response.json())} agents")
    else:
        logger.error(f"Failed to list agents: {response.text}")


def test_completion(agent_id):
    """Test running a completion."""
    logger.info("Testing completion")

    payload = {
        "prompt": "What is the weather like today?",
        "agent_id": agent_id,
    }

    response = requests.post(f"{BASE_URL}/agent/completions", json=payload)
    logger.debug(f"Completion response: {response.json()}")

    if response.status_code == 200:
        logger.success("Successfully got completion")
    else:
        logger.error(f"Failed to get completion: {response.text}")


def test_delete_agent(agent_id):
    """Test deleting an agent."""
    logger.info("Testing agent deletion")

    response = requests.delete(f"{BASE_URL}/agent/{agent_id}")
    logger.debug(f"Delete response: {response.json()}")

    if response.status_code == 200:
        logger.success("Successfully deleted agent")
    else:
        logger.error(f"Failed to delete agent: {response.text}")


def run_tests():
    """Run all tests in sequence."""
    logger.info("Starting API tests")

    # Create agent and get ID
    agent_id = test_create_agent()
    if not agent_id:
        logger.error("Cannot continue tests without agent ID")
        return

    # Wait a bit for agent to be ready
    time.sleep(1)

    # Run other tests
    test_list_agents()
    test_completion(agent_id)
    test_delete_agent(agent_id)

    logger.info("Tests completed")


if __name__ == "__main__":
    run_tests()
@@ -0,0 +1,898 @@
from enum import Enum
from typing import Union, Optional
import io
from PIL import Image
import numpy as np
import torch
import struct


from enum import auto
from typing import List, Dict, Tuple, Any
import wave
import magic  # python-magic; required by ModalityDetector below
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from einops import rearrange
from torch import Tensor


@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False
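
# A quick sanity check of the default head arithmetic above: head_dim =
# hidden_size / num_query_heads = 1024 / 32 = 32, and each of the 8 key/value
# heads is shared by 32 / 8 = 4 query heads (expanded via repeat_interleave
# in MultiQueryAttention below).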


class MultiQueryAttention(nn.Module):
    """Fixed Multi-Query Attention implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_query_heads * self.head_dim
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim, config.hidden_size
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project and reshape
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [seq_len, batch, heads, head_dim]
        q = q.view(
            batch_size, seq_length, self.num_query_heads, self.head_dim
        ).permute(1, 0, 2, 3)
        k = k.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).permute(1, 0, 2, 3)
        v = v.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).permute(1, 0, 2, 3)

        # Apply rotary embeddings
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Handle MQA head expansion
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads, dim=2
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads, dim=2
            )

        # Compute attention
        # Reshape for matmul: [batch, heads, seq_length, head_dim]
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)

        return output
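
# A small, hypothetical shape check for the attention block above; the reduced
# ModelConfig values are illustrative, not the defaults used elsewhere in this
# file.
def _attention_shape_example() -> None:
    cfg = ModelConfig(hidden_size=256, num_query_heads=8, num_key_value_heads=4)
    attn = MultiQueryAttention(cfg)
    x = torch.randn(2, 16, cfg.hidden_size)  # (batch, sequence, hidden)
    assert attn(x).shape == (2, 16, cfg.hidden_size)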


class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size, eps=config.layer_norm_eps
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(config.hidden_size, 4 * config.hidden_size),
                            nn.GELU(),
                            nn.Linear(4 * config.hidden_size, config.hidden_size),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size, eps=config.layer_norm_eps
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Create causal mask if needed
        if attention_mask is None:
            causal = torch.triu(
                torch.ones(
                    (input_ids.size(1), input_ids.size(1)),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            # Convert the boolean mask into an additive float mask (0 where
            # attention is allowed, -inf above the diagonal) so it can be
            # added directly to the attention scores.
            attention_mask = torch.zeros(
                causal.shape, device=input_ids.device, dtype=hidden_states.dtype
            ).masked_fill(causal, float("-inf"))

        # Apply transformer layers
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss

    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        logits[i, token_id] /= repetition_penalty

            # Apply top-k sampling
            if top_k is not None:
                indices_to_remove = (
                    logits < torch.topk(logits, top_k)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        return tensor_to_data(tensor_data)
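
# A minimal, hypothetical smoke test of the training interface above; the
# shrunken ModelConfig values and the random byte batch are illustrative
# assumptions, not defaults from this file.
def _training_step_example() -> None:
    cfg = ModelConfig(
        hidden_size=128,
        num_layers=2,
        num_query_heads=8,
        num_key_value_heads=4,
        max_sequence_length=64,
    )
    model = EnhancedBytePredictor(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    batch = torch.randint(0, cfg.vocab_size, (2, 32))  # random stand-in "bytes"
    loss = model.compute_loss(batch[:, :-1], batch[:, 1:])  # next-byte prediction
    loss.backward()
    optimizer.step()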


# import torch
# from typing import Optional


class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"


class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(byte_indices.cpu().numpy().astype(np.uint8).tolist())

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Try with error handling
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If failed, assume raw pixel data
            if not size:
                # Try to determine size from byte count
                pixel_count = len(byte_sequence) // len(mode)
                size = (
                    int(np.sqrt(pixel_count)),
                    int(np.sqrt(pixel_count)),
                )

            # Convert raw bytes to pixel array
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((*size, len(mode)))

            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine format string based on sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack bytes to samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to desired format.

        Args:
            model_output: Either tensor from model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError("Video decoding not yet implemented")
        else:  # BINARY
            return byte_sequence


# Usage example
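# A minimal, hypothetical sketch of how ByteDetokenizer can be used; the
# random logits below only stand in for real model output.
def _detokenizer_example() -> None:
    detok = ByteDetokenizer()
    fake_logits = torch.randn(16, 256)  # (sequence, vocab) stand-in output
    decoded = detok.decode_data(fake_logits, DataType.TEXT)
    print(repr(decoded))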


class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None


class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers)
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"RIFF": "WEBP",
        # Audio
        b"RIFF....WAVE": "WAV",
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }
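    # Note: the b"RIFF....WAVE" entry contains literal '.' characters, so a
    # plain startswith() check never matches a real WAV header; WAV data is
    # parsed separately in _check_audio_validity below.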
|
||||||
|
def __init__(self):
|
||||||
|
self.magic = magic.Magic(mime=True)
|
||||||
|
|
||||||
|
def _check_text_probability(self, data: bytes) -> float:
|
||||||
|
"""Estimate probability that data is text."""
|
||||||
|
# Check if data is valid UTF-8
|
||||||
|
try:
|
||||||
|
data.decode("utf-8")
|
||||||
|
# Count printable ASCII characters
|
||||||
|
printable = sum(1 for b in data if 32 <= b <= 126)
|
||||||
|
return printable / len(data)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
|
||||||
|
"""Check if data is a valid image and extract metadata."""
|
||||||
|
try:
|
||||||
|
with io.BytesIO(data) as bio:
|
||||||
|
img = Image.open(bio)
|
||||||
|
return True, {
|
||||||
|
"format": img.format,
|
||||||
|
"size": img.size,
|
||||||
|
"mode": img.mode,
|
||||||
|
}
|
||||||
|
        except Exception:
|
||||||
|
return False, {}
|
||||||
|
|
||||||
|
def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
|
||||||
|
"""Check if data is valid audio and extract metadata."""
|
||||||
|
try:
|
||||||
|
with io.BytesIO(data) as bio:
|
||||||
|
# Try to parse as WAV
|
||||||
|
with wave.open(bio) as wav:
|
||||||
|
return True, {
|
||||||
|
"channels": wav.getnchannels(),
|
||||||
|
"sample_width": wav.getsampwidth(),
|
||||||
|
"framerate": wav.getframerate(),
|
||||||
|
"frames": wav.getnframes(),
|
||||||
|
}
|
||||||
|
        except Exception:
|
||||||
|
# Check for other audio signatures
|
||||||
|
for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
|
||||||
|
if data.startswith(sig):
|
||||||
|
return True, {"format": "compressed_audio"}
|
||||||
|
return False, {}
|
||||||
|
|
||||||
|
def _detect_boundaries(
|
||||||
|
self, data: bytes
|
||||||
|
) -> List[Tuple[int, int, Modality]]:
|
||||||
|
"""Detect boundaries between different modalities in the data."""
|
||||||
|
boundaries = []
|
||||||
|
current_pos = 0
|
||||||
|
|
||||||
|
while current_pos < len(data):
|
||||||
|
# Look for known signatures
|
||||||
|
for sig, format_type in self.SIGNATURES.items():
|
||||||
|
if data[current_pos:].startswith(sig):
|
||||||
|
# Found a signature, determine its length
|
||||||
|
if format_type in ["JPEG", "PNG", "GIF"]:
|
||||||
|
# Find image end
|
||||||
|
try:
|
||||||
|
with io.BytesIO(
|
||||||
|
data[current_pos:]
|
||||||
|
) as bio:
|
||||||
|
img = Image.open(bio)
|
||||||
|
img.verify()
|
||||||
|
size = bio.tell()
|
||||||
|
boundaries.append(
|
||||||
|
(
|
||||||
|
current_pos,
|
||||||
|
current_pos + size,
|
||||||
|
Modality.IMAGE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_pos += size
|
||||||
|
continue
|
||||||
|
                        except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check for text sections
|
||||||
|
text_prob = self._check_text_probability(
|
||||||
|
data[current_pos : current_pos + 1024]
|
||||||
|
)
|
||||||
|
if text_prob > 0.8:
|
||||||
|
# Look for end of text section
|
||||||
|
end_pos = current_pos + 1
|
||||||
|
while end_pos < len(data):
|
||||||
|
if (
|
||||||
|
self._check_text_probability(
|
||||||
|
data[end_pos : end_pos + 32]
|
||||||
|
)
|
||||||
|
< 0.5
|
||||||
|
):
|
||||||
|
break
|
||||||
|
end_pos += 1
|
||||||
|
boundaries.append(
|
||||||
|
(current_pos, end_pos, Modality.TEXT)
|
||||||
|
)
|
||||||
|
current_pos = end_pos
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_pos += 1
|
||||||
|
|
||||||
|
return boundaries
|
||||||
|
|
||||||
|
def detect_modality(self, data: bytes) -> ModalityInfo:
|
||||||
|
"""Detect modality of byte sequence."""
|
||||||
|
# First check for single modality
|
||||||
|
mime_type = self.magic.from_buffer(data)
|
||||||
|
|
||||||
|
# Check text
|
||||||
|
text_prob = self._check_text_probability(data)
|
||||||
|
if text_prob > 0.9:
|
||||||
|
return ModalityInfo(
|
||||||
|
modality=Modality.TEXT,
|
||||||
|
confidence=text_prob,
|
||||||
|
metadata={"mime_type": mime_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check image
|
||||||
|
is_image, image_meta = self._check_image_validity(data)
|
||||||
|
if is_image:
|
||||||
|
return ModalityInfo(
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
confidence=1.0,
|
||||||
|
metadata={**image_meta, "mime_type": mime_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check audio
|
||||||
|
is_audio, audio_meta = self._check_audio_validity(data)
|
||||||
|
if is_audio:
|
||||||
|
return ModalityInfo(
|
||||||
|
modality=Modality.AUDIO,
|
||||||
|
confidence=1.0,
|
||||||
|
metadata={**audio_meta, "mime_type": mime_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for multimodal content
|
||||||
|
boundaries = self._detect_boundaries(data)
|
||||||
|
if len(boundaries) > 1:
|
||||||
|
sub_modalities = []
|
||||||
|
for start, end, modality in boundaries:
|
||||||
|
chunk_data = data[start:end]
|
||||||
|
sub_info = self.detect_modality(chunk_data)
|
||||||
|
if sub_info.modality != Modality.BINARY:
|
||||||
|
sub_modalities.append(sub_info)
|
||||||
|
|
||||||
|
if sub_modalities:
|
||||||
|
return ModalityInfo(
|
||||||
|
modality=Modality.MULTIMODAL,
|
||||||
|
confidence=0.8,
|
||||||
|
metadata={"mime_type": "multipart/mixed"},
|
||||||
|
sub_modalities=sub_modalities,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default to binary
|
||||||
|
return ModalityInfo(
|
||||||
|
modality=Modality.BINARY,
|
||||||
|
confidence=0.5,
|
||||||
|
metadata={"mime_type": mime_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_modalities(
|
||||||
|
self, data: bytes
|
||||||
|
) -> List[Tuple[Modality, bytes, Dict]]:
|
||||||
|
"""Split multimodal data into separate modalities."""
|
||||||
|
boundaries = self._detect_boundaries(data)
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for start, end, modality in boundaries:
|
||||||
|
chunk = data[start:end]
|
||||||
|
info = self.detect_modality(chunk)
|
||||||
|
result.append((modality, chunk, info.metadata))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AutoDetectBytesDecoder:
|
||||||
|
"""Decoder that automatically detects and decodes different modalities."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.detector = ModalityDetector()
|
||||||
|
self.text_decoder = ByteDetokenizer() # From previous example
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self, data: bytes
|
||||||
|
) -> Union[str, Image.Image, np.ndarray, List[any]]:
|
||||||
|
"""Automatically detect and decode byte sequence."""
|
||||||
|
info = self.detector.detect_modality(data)
|
||||||
|
|
||||||
|
if info.modality == Modality.MULTIMODAL:
|
||||||
|
# Handle multimodal content
|
||||||
|
parts = self.detector.split_modalities(data)
|
||||||
|
return [
|
||||||
|
self.decode(chunk) for modality, chunk, _ in parts
|
||||||
|
]
|
||||||
|
|
||||||
|
if info.modality == Modality.TEXT:
|
||||||
|
return self.text_decoder.decode_text(data)
|
||||||
|
elif info.modality == Modality.IMAGE:
|
||||||
|
return self.text_decoder.decode_image(data)
|
||||||
|
elif info.modality == Modality.AUDIO:
|
||||||
|
return self.text_decoder.decode_audio(data)
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# # Example usage
|
||||||
|
# def demo_auto_detection():
|
||||||
|
# """Demonstrate auto modality detection."""
|
||||||
|
# # Create mixed content
|
||||||
|
# text = "Hello, World!".encode('utf-8')
|
||||||
|
|
||||||
|
# # Create a small test image
|
||||||
|
# img = Image.new('RGB', (100, 100), color='red')
|
||||||
|
# img_bytes = io.BytesIO()
|
||||||
|
# img.save(img_bytes, format='PNG')
|
||||||
|
|
||||||
|
# # Combine into multimodal content
|
||||||
|
# mixed_content = text + img_bytes.getvalue()
|
||||||
|
|
||||||
|
# # Initialize decoder
|
||||||
|
# decoder = AutoDetectBytesDecoder()
|
||||||
|
|
||||||
|
# # Decode
|
||||||
|
# result = decoder.decode(mixed_content)
|
||||||
|
|
||||||
|
# if isinstance(result, list):
|
||||||
|
# print("Detected multimodal content:")
|
||||||
|
# for i, part in enumerate(result):
|
||||||
|
# print(f"Part {i+1}: {type(part)}")
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# demo_auto_detection()
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_data(tensor: Tensor):
|
||||||
|
byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)
|
||||||
|
|
||||||
|
# Initialize auto-detector
|
||||||
|
decoder = AutoDetectBytesDecoder()
|
||||||
|
|
||||||
|
# Decode with automatic detection
|
||||||
|
result = decoder.decode(byte_sequence)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def demo_byte_predictor():
|
||||||
|
"""Demo with smaller dimensions to test."""
|
||||||
|
# Initialize model configuration with adjusted dimensions
|
||||||
|
config = ModelConfig(
|
||||||
|
vocab_size=256,
|
||||||
|
hidden_size=128, # Smaller for testing
|
||||||
|
num_layers=2, # Fewer layers for testing
|
||||||
|
num_key_value_heads=2,
|
||||||
|
num_query_heads=4,
|
||||||
|
dropout=0.1,
|
||||||
|
max_sequence_length=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
model = EnhancedBytePredictor(config)
|
||||||
|
logger.info("Model initialized")
|
||||||
|
|
||||||
|
# Move to GPU if available
|
||||||
|
device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
model = model.to(device)
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Create sample input data
|
||||||
|
batch_size = 2
|
||||||
|
seq_length = 16 # Shorter sequence for testing
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_length), device=device
|
||||||
|
)
|
||||||
|
logger.info(f"Created input tensor of shape: {input_ids.shape}")
|
||||||
|
|
||||||
|
# Test forward pass
|
||||||
|
try:
|
||||||
|
logits = model(input_ids)
|
||||||
|
logger.info(
|
||||||
|
f"Forward pass successful! Output shape: {logits.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test loss computation
|
||||||
|
target_ids = torch.randint(
|
||||||
|
0,
|
||||||
|
config.vocab_size,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
loss = model.compute_loss(input_ids, target_ids)
|
||||||
|
logger.info(
|
||||||
|
f"Loss computation successful! Loss value: {loss.item():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test generation
|
||||||
|
prompt = torch.randint(
|
||||||
|
0,
|
||||||
|
config.vocab_size,
|
||||||
|
(1, 4), # Very short prompt for testing
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
generated = model.generate(
|
||||||
|
prompt, max_new_tokens=8, temperature=0.8, top_k=50
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Generation successful! Generated shape: {generated.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during execution: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Set up logging
|
||||||
|
# logger.remove() # Remove default handler
|
||||||
|
# logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")
|
||||||
|
|
||||||
|
demo_byte_predictor()
|
@ -0,0 +1,135 @@
|
|||||||
|
# OpenAI Assistant
|
||||||
|
|
||||||
|
The OpenAI Assistant class provides a wrapper around OpenAI's Assistants API, integrating it with the swarms framework.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `OpenAIAssistant` class allows you to create and interact with OpenAI Assistants, providing a simple interface for:
|
||||||
|
|
||||||
|
- Creating assistants with specific roles and capabilities
|
||||||
|
- Adding custom functions that the assistant can call
|
||||||
|
- Managing conversation threads
|
||||||
|
- Handling tool calls and function execution
|
||||||
|
- Getting responses from the assistant
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install swarms
|
||||||
|
```
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
from swarms import OpenAIAssistant
|
||||||
|
|
||||||
|
# Create an assistant
|
||||||
|
assistant = OpenAIAssistant(
|
||||||
|
name="Math Tutor",
|
||||||
|
instructions="You are a helpful math tutor.",
|
||||||
|
model="gpt-4o",
|
||||||
|
tools=[{"type": "code_interpreter"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run a task
|
||||||
|
response = assistant.run("Solve the equation: 3x + 11 = 14")
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
# Continue the conversation in the same thread
|
||||||
|
follow_up = assistant.run("Now explain how you solved it")
|
||||||
|
print(follow_up)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Function Calling
|
||||||
|
|
||||||
|
The assistant supports custom function integration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
def get_weather(location: str, unit: str = "celsius") -> str:
|
||||||
|
# Mock weather function
|
||||||
|
return f"The weather in {location} is 22 degrees {unit}"
|
||||||
|
|
||||||
|
# Add function to assistant
|
||||||
|
assistant.add_function(
    get_weather,
    description="Get the current weather in a location",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
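
Once registered, the assistant can choose to call the function during a run; the tool call is executed locally and its result is submitted back automatically. A hypothetical follow-up request might look like this:

```python
# Ask something that should trigger the registered get_weather function
reply = assistant.run("What's the weather in Paris, in fahrenheit?")
print(reply)
```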
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Constructor
|
||||||
|
|
||||||
|
```python
|
||||||
|
OpenAIAssistant(
|
||||||
|
name: str,
|
||||||
|
instructions: Optional[str] = None,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
file_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
functions: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Methods
|
||||||
|
|
||||||
|
#### run(task: str) -> str
|
||||||
|
Sends a task to the assistant and returns its response. The conversation thread is maintained between calls.
|
||||||
|
|
||||||
|
#### add_function(func: Callable, description: str, parameters: Dict[str, Any]) -> None
|
||||||
|
Adds a callable function that the assistant can use during conversations.
|
||||||
|
|
||||||
|
#### add_message(content: str, file_ids: Optional[List[str]] = None) -> None
|
||||||
|
Adds a message to the current conversation thread.
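
A minimal usage sketch (the file ID below is a placeholder for a file previously uploaded to OpenAI):

```python
# Attach a previously uploaded file to the thread before running a task
assistant.add_message(
    "Here is the quarterly report to discuss.",
    file_ids=["file-abc123"]  # placeholder ID
)
response = assistant.run("Summarize the attached report in three bullet points.")
print(response)
```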
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The assistant implements robust error handling (a caller-side retry sketch follows this list):
|
||||||
|
- Retries on rate limits
|
||||||
|
- Graceful handling of API errors
|
||||||
|
- Clear error messages for debugging
|
||||||
|
- Status monitoring for runs and completions
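
Callers that need stricter guarantees can add their own guard around `run`. A minimal sketch, assuming the `openai` package's `RateLimitError` exception; this wrapper is not part of the class itself:

```python
import time

from openai import RateLimitError

def run_with_retry(assistant, task: str, attempts: int = 3, base_delay: float = 2.0) -> str:
    """Retry a task with exponential backoff when the API rate-limits us."""
    for attempt in range(attempts):
        try:
            return assistant.run(task)
        except RateLimitError:
            if attempt == attempts - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))

print(run_with_retry(assistant, "Integrate x^2 from 0 to 3"))
```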
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. Thread Management
|
||||||
|
- Use the same assistant instance for related conversations
|
||||||
|
- Create new instances for unrelated tasks
|
||||||
|
- Monitor thread status during long-running operations
|
||||||
|
|
||||||
|
2. Function Integration
|
||||||
|
- Keep functions simple and focused
|
||||||
|
- Provide clear descriptions and parameter schemas
|
||||||
|
- Handle errors gracefully in custom functions
|
||||||
|
- Test functions independently before integration
|
||||||
|
|
||||||
|
3. Performance
|
||||||
|
- Reuse assistant instances when possible
|
||||||
|
- Monitor and handle rate limits appropriately
|
||||||
|
- Use appropriate polling intervals for status checks
|
||||||
|
- Consider implementing timeouts for long-running operations
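
One way to enforce the timeout recommendation above from the caller's side, sketched with the standard library only (not part of the `OpenAIAssistant` API):

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError

def run_with_timeout(assistant, task: str, timeout_s: float = 120.0) -> str:
    """Give up waiting on the assistant after timeout_s seconds."""
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(assistant.run, task)
    try:
        return future.result(timeout=timeout_s)
    except TimeoutError:
        return "Request timed out; consider retrying or simplifying the task."
    finally:
        # Don't block on the worker thread; it is abandoned if still running.
        pool.shutdown(wait=False)
```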
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [OpenAI Assistants API Documentation](https://platform.openai.com/docs/assistants/overview)
|
||||||
|
- [OpenAI Function Calling Guide](https://platform.openai.com/docs/guides/function-calling)
|
||||||
|
- [OpenAI Rate Limits](https://platform.openai.com/docs/guides/rate-limits)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||||
|
from swarm_models.openai_function_caller import OpenAIFunctionCaller
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicParser:
|
||||||
|
@staticmethod
|
||||||
|
def extract_fields(model: Type[BaseModel]) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
field_name: (field.annotation, ... if field.is_required() else None)
|
||||||
|
for field_name, field in model.model_fields.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_partial_model(model: Type[BaseModel], data: Dict[str, Any]) -> Type[BaseModel]:
|
||||||
|
fields = {
|
||||||
|
field_name: (field.annotation, ... if field.is_required() else None)
|
||||||
|
for field_name, field in model.model_fields.items()
|
||||||
|
if field_name in data
|
||||||
|
}
|
||||||
|
return create_model(f"Partial{model.__name__}", **fields)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, data: Union[str, Dict[str, Any]], model: Type[BaseModel]) -> Optional[BaseModel]:
|
||||||
|
if isinstance(data, str):
|
||||||
|
try:
|
||||||
|
data = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try full model first
|
||||||
|
with suppress(ValidationError):
|
||||||
|
return model.model_validate(data)
|
||||||
|
|
||||||
|
# Create and try partial model
|
||||||
|
partial_model = cls.create_partial_model(model, data)
|
||||||
|
with suppress(ValidationError):
|
||||||
|
return partial_model.model_validate(data)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Define the Thoughts schema
|
||||||
|
class Thoughts(BaseModel):
|
||||||
|
text: str = Field(..., description="Current thoughts or observations regarding the task.")
|
||||||
|
reasoning: str = Field(..., description="Logical reasoning behind the thought process.")
|
||||||
|
plan: str = Field(..., description="A short bulleted list that conveys the immediate and long-term plan.")
|
||||||
|
criticism: str = Field(..., description="Constructive self-criticism to improve future responses.")
|
||||||
|
speak: str = Field(..., description="A concise summary of thoughts intended for the user.")
|
||||||
|
|
||||||
|
# Define the Command schema
|
||||||
|
class Command(BaseModel):
|
||||||
|
name: str = Field(..., description="Command name to execute from the provided list of commands.")
|
||||||
|
args: Dict[str, Any] = Field(..., description="Arguments required to execute the command.")
|
||||||
|
|
||||||
|
# Define the AgentResponse schema
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
thoughts: Thoughts = Field(..., description="The agent's current thoughts and reasoning.")
|
||||||
|
command: Command = Field(..., description="The command to execute along with its arguments.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Define tool functions
|
||||||
|
def fluid_api_command(task: str):
|
||||||
|
"""Execute a fluid API request."""
|
||||||
|
    # The real request is disabled because fluid_api_request is not imported in this module.
    # response = fluid_api_request(task)
    # print(response.model_dump_json(indent=4))
    # return response
    print(f"Fluid API request skipped (fluid_api_request unavailable) for task: {task}")
    return {"status": "skipped", "message": f"Fluid API not configured for task: {task}"}
|
||||||
|
|
||||||
|
|
||||||
|
def send_tweet_command(text: str):
|
||||||
|
"""Simulate sending a tweet."""
|
||||||
|
print(f"Tweet sent: {text}")
|
||||||
|
return {"status": "success", "message": f"Tweet sent: {text}"}
|
||||||
|
|
||||||
|
|
||||||
|
def do_nothing_command():
|
||||||
|
"""Do nothing."""
|
||||||
|
print("Doing nothing...")
|
||||||
|
return {"status": "success", "message": "No action taken."}
|
||||||
|
|
||||||
|
|
||||||
|
def task_complete_command(reason: str):
|
||||||
|
"""Mark the task as complete and provide a reason."""
|
||||||
|
print(f"Task completed: {reason}")
|
||||||
|
return {"status": "success", "message": f"Task completed: {reason}"}
|
||||||
|
|
||||||
|
|
||||||
|
# Dynamic command execution
|
||||||
|
def execute_command(name: str, args: Dict[str, Any]):
|
||||||
|
"""Dynamically execute a command based on its name and arguments."""
|
||||||
|
command_map: Dict[str, Callable] = {
|
||||||
|
"fluid_api": lambda **kwargs: fluid_api_command(task=kwargs.get("task")),
|
||||||
|
"send_tweet": lambda **kwargs: send_tweet_command(text=kwargs.get("text")),
|
||||||
|
"do_nothing": lambda **kwargs: do_nothing_command(),
|
||||||
|
"task_complete": lambda **kwargs: task_complete_command(reason=kwargs.get("reason")),
|
||||||
|
}
|
||||||
|
|
||||||
|
if name not in command_map:
|
||||||
|
raise ValueError(f"Unknown command: {name}")
|
||||||
|
|
||||||
|
# Execute the command with the provided arguments
|
||||||
|
return command_map[name](**args)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_execute_command(response: Union[str, Dict[str, Any]], base_model: Type[BaseModel] = AgentResponse) -> Any:
|
||||||
|
"""Enhanced command parser with flexible input handling"""
|
||||||
|
parsed = DynamicParser.parse(response, base_model)
|
||||||
|
if not parsed:
|
||||||
|
raise ValueError("Failed to parse response")
|
||||||
|
|
||||||
|
if hasattr(parsed, 'command'):
|
||||||
|
command_name = parsed.command.name
|
||||||
|
command_args = parsed.command.args
|
||||||
|
return execute_command(command_name, command_args)
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
ainame = "AutoAgent"
|
||||||
|
userprovided = "assistant"
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = f"""
|
||||||
|
You are {ainame}, an advanced and autonomous {userprovided}.
|
||||||
|
Your role is to make decisions and complete tasks independently without seeking user assistance. Leverage your strengths as an LLM to solve tasks efficiently, adhering strictly to the commands and resources provided.
|
||||||
|
|
||||||
|
### GOALS:
|
||||||
|
1. {userprovided}
|
||||||
|
2. Execute tasks with precision and efficiency.
|
||||||
|
3. Ensure outputs are actionable and aligned with the user's objectives.
|
||||||
|
4. Continuously optimize task strategies for maximum effectiveness.
|
||||||
|
5. Maintain reliability and consistency in all responses.
|
||||||
|
|
||||||
|
### CONSTRAINTS:
|
||||||
|
1. Memory limit: ~4000 words for short-term memory. Save essential information to files immediately to avoid loss.
|
||||||
|
2. Independent decision-making: Do not rely on user assistance.
|
||||||
|
3. Exclusively use commands in double quotes (e.g., "command name").
|
||||||
|
4. Use subprocesses for commands that may take longer than a few minutes.
|
||||||
|
5. Ensure all outputs strictly adhere to the specified JSON response format.
|
||||||
|
|
||||||
|
### COMMANDS:
|
||||||
|
1. Fluid API: "fluid_api", args: "method": "<GET/POST/...>", "url": "<url>", "headers": "<headers>", "body": "<payload>"
|
||||||
|
18. Send Tweet: "send_tweet", args: "text": "<text>"
|
||||||
|
19. Do Nothing: "do_nothing", args:
|
||||||
|
20. Task Complete (Shutdown): "task_complete", args: "reason": "<reason>"
|
||||||
|
|
||||||
|
### RESOURCES:
|
||||||
|
1. Internet access for real-time information and data gathering.
|
||||||
|
2. Long-term memory management for storing critical information.
|
||||||
|
3. Access to GPT-3.5-powered Agents for delegating tasks.
|
||||||
|
4. File handling capabilities for output storage and retrieval.
|
||||||
|
|
||||||
|
### PERFORMANCE EVALUATION:
|
||||||
|
1. Continuously analyze and reflect on actions to ensure optimal task completion.
|
||||||
|
2. Self-critique decisions and strategies constructively to identify areas for improvement.
|
||||||
|
3. Ensure every command serves a clear purpose and minimizes resource usage.
|
||||||
|
4. Complete tasks in the least number of steps, balancing speed and accuracy.
|
||||||
|
|
||||||
|
### RESPONSE FORMAT:
|
||||||
|
Always respond in strict JSON that matches the AgentResponse schema (thoughts and command). Ensure your responses can be parsed with Python's `json.loads`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize the OpenAIFunctionCaller
|
||||||
|
model = OpenAIFunctionCaller(
|
||||||
|
system_prompt=SYSTEM_PROMPT,
|
||||||
|
max_tokens=4000,
|
||||||
|
temperature=0.9,
|
||||||
|
base_model=AgentResponse, # Pass the Pydantic schema as the base model
|
||||||
|
parallel_tool_calls=False,
|
||||||
|
openai_api_key=os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
user_input = (
|
||||||
|
"Analyze the provided Python code for inefficiencies, generate suggestions for improvements, "
|
||||||
|
"and provide optimized code."
|
||||||
|
)
|
||||||
|
|
||||||
|
response = model.run(user_input)
|
||||||
|
response = parse_and_execute_command(response)
|
||||||
|
print(response)
|
File diff suppressed because it is too large
@ -0,0 +1,7 @@
|
|||||||
|
from swarms import Agent
|
||||||
|
|
||||||
|
Agent(
|
||||||
|
agent_name="Stock-Analysis-Agent",
|
||||||
|
model_name="gpt-4o-mini",
|
||||||
|
max_loops=1,
|
||||||
|
).run("What are 5 hft algorithms")
|
@ -0,0 +1,253 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
|
|
||||||
|
from swarms import Agent
|
||||||
|
from swarms.agents.create_agents_from_yaml import (
|
||||||
|
create_agents_from_yaml,
|
||||||
|
)
|
||||||
|
from swarms.utils.formatter import formatter
|
||||||
|
from swarms.utils.litellm import LiteLLM
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_yaml_for_parsing(raw_yaml: str) -> str:
|
||||||
|
"""
|
||||||
|
Prepares raw YAML content by fixing spacing and formatting issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_yaml (str): The raw YAML content extracted from Markdown.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The cleaned YAML content ready for parsing.
|
||||||
|
"""
|
||||||
|
# Fix sequence items that are improperly placed on the same line as their key
|
||||||
|
fixed_yaml = re.sub(
|
||||||
|
r"(\b\w+\b):\s*-\s*", r"\1:\n - ", raw_yaml
|
||||||
|
) # Fix "key: - value" to "key:\n - value"
|
||||||
|
|
||||||
|
# Ensure proper spacing after colons
|
||||||
|
fixed_yaml = re.sub(
|
||||||
|
r"(\S):(\S)", r"\1: \2", fixed_yaml
|
||||||
|
) # Ensure space after colons
|
||||||
|
|
||||||
|
# Remove trailing spaces before newlines
|
||||||
|
fixed_yaml = re.sub(r"\s+\n", "\n", fixed_yaml)
|
||||||
|
|
||||||
|
# Replace non-breaking spaces (if any) with regular spaces
|
||||||
|
fixed_yaml = fixed_yaml.replace("\xa0", " ")
|
||||||
|
|
||||||
|
return fixed_yaml.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml_from_swarm_markdown(markdown_text: str) -> str:
|
||||||
|
"""
|
||||||
|
    Extracts YAML content from a Markdown-style 'Auto-Swarm-Builder' block and normalizes it for parsing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
markdown_text (str): The Markdown text containing the YAML inside 'Auto-Swarm-Builder' block.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
        str: The extracted and normalized YAML content, ready for parsing.
|
||||||
|
"""
|
||||||
|
# Match the 'Auto-Swarm-Builder' block with YAML inside triple backticks
|
||||||
|
pattern = r"```yaml\s*\n(.*?)```"
|
||||||
|
match = re.search(pattern, markdown_text, re.DOTALL)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
raise ValueError(
|
||||||
|
"No YAML content found in the 'Auto-Swarm-Builder' block."
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_yaml = match.group(1).strip()
|
||||||
|
|
||||||
|
# Preprocess and normalize the YAML content
|
||||||
|
normalized_yaml = prepare_yaml_for_parsing(raw_yaml)
|
||||||
|
|
||||||
|
return normalized_yaml
|
||||||
|
|
||||||
|
|
||||||
|
AUTO_GEN_PROMPT = """
|
||||||
|
You are a specialized agent responsible for creating YAML configuration files for multi-agent swarms. Your role is to generate well-structured YAML that defines both individual agents and swarm architectures based on user requirements.
|
||||||
|
Output only the YAML, nothing else. You will be penalized for making mistakes.
|
||||||
|
|
||||||
|
GUIDELINES:
|
||||||
|
1. Each YAML file must contain an `agents` section with at least one agent configuration
|
||||||
|
2. Each agent configuration requires the following mandatory fields:
|
||||||
|
- agent_name (string)
|
||||||
|
- system_prompt (string)
|
||||||
|
|
||||||
|
3. Optional agent fields include:
|
||||||
|
- max_loops (integer)
|
||||||
|
- autosave (boolean)
|
||||||
|
- dashboard (boolean)
|
||||||
|
- verbose (boolean)
|
||||||
|
- dynamic_temperature_enabled (boolean)
|
||||||
|
- saved_state_path (string)
|
||||||
|
- user_name (string)
|
||||||
|
- retry_attempts (integer)
|
||||||
|
- context_length (integer)
|
||||||
|
- return_step_meta (boolean)
|
||||||
|
- output_type (string)
|
||||||
|
- task (string)
|
||||||
|
|
||||||
|
4. When a swarm is needed, include a `swarm_architecture` section with:
|
||||||
|
Mandatory fields:
|
||||||
|
- name (string)
|
||||||
|
- swarm_type (string: "ConcurrentWorkflow" or "SequentialWorkflow") [AgentRearrange, MixtureOfAgents, SpreadSheetSwarm, SequentialWorkflow, ConcurrentWorkflow]
|
||||||
|
|
||||||
|
Optional fields:
|
||||||
|
- description (string)
|
||||||
|
- max_loops (integer)
|
||||||
|
- task (string)
|
||||||
|
|
||||||
|
TEMPLATE STRUCTURE:
|
||||||
|
```yaml
|
||||||
|
agents:
|
||||||
|
- agent_name: "Agent-1-Name"
|
||||||
|
system_prompt: "Detailed system prompt here"
|
||||||
|
max_loops: 1
|
||||||
|
# [additional optional fields]
|
||||||
|
|
||||||
|
- agent_name: "Agent-2-Name"
|
||||||
|
system_prompt: "Detailed system prompt here"
|
||||||
|
# [additional optional fields]
|
||||||
|
|
||||||
|
swarm_architecture:
|
||||||
|
name: "Swarm-Name"
|
||||||
|
description: "Swarm purpose and goals"
|
||||||
|
swarm_type: "ConcurrentWorkflow"
|
||||||
|
max_loops: 5
|
||||||
|
task: "Main swarm task description"
|
||||||
|
```
|
||||||
|
|
||||||
|
VALIDATION RULES:
|
||||||
|
1. All agent names must be unique
|
||||||
|
2. System prompts must be clear and specific to the agent's role
|
||||||
|
3. Integer values must be positive
|
||||||
|
4. Boolean values must be true or false (lowercase)
|
||||||
|
5. File paths should use forward slashes
|
||||||
|
6. Tasks should be specific and aligned with the agent/swarm purpose
|
||||||
|
|
||||||
|
When generating a YAML configuration:
|
||||||
|
1. Ask for specific requirements about the agents and swarm needed
|
||||||
|
2. Determine if a swarm architecture is necessary based on the task complexity
|
||||||
|
3. Generate appropriate system prompts for each agent based on their roles
|
||||||
|
4. Include relevant optional fields based on the use case
|
||||||
|
5. Validate the configuration against all rules before returning
|
||||||
|
|
||||||
|
Example valid YAML configurations are provided below. Use these as references for structure and formatting:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
|
||||||
|
|
||||||
|
agents:
|
||||||
|
- agent_name: "Data-Analysis-Agent"
|
||||||
|
system_prompt: "You are a specialized data analysis agent focused on processing and interpreting financial data. Provide clear, actionable insights based on the data provided."
|
||||||
|
max_loops: 3
|
||||||
|
autosave: true
|
||||||
|
verbose: true
|
||||||
|
context_length: 100000
|
||||||
|
output_type: "json"
|
||||||
|
task: "Analyze quarterly financial reports and identify trends"
|
||||||
|
|
||||||
|
# Multi-Agent Swarm Example
|
||||||
|
agents:
|
||||||
|
- agent_name: "Research-Agent"
|
||||||
|
system_prompt: "You are a research agent specialized in gathering and summarizing scientific publications. Focus on peer-reviewed sources and provide comprehensive summaries."
|
||||||
|
max_loops: 2
|
||||||
|
context_length: 150000
|
||||||
|
output_type: "str"
|
||||||
|
|
||||||
|
- agent_name: "Analysis-Agent"
|
||||||
|
system_prompt: "You are an analysis agent that processes research summaries and identifies key patterns and insights. Provide detailed analytical reports."
|
||||||
|
max_loops: 3
|
||||||
|
context_length: 200000
|
||||||
|
output_type: "json"
|
||||||
|
|
||||||
|
swarm_architecture:
|
||||||
|
name: "Research-Analysis-Swarm"
|
||||||
|
description: "A swarm for comprehensive research analysis and insight generation"
|
||||||
|
swarm_type: "SequentialWorkflow"
|
||||||
|
max_loops: 5
|
||||||
|
task: "Research and analyze recent developments in quantum computing"
|
||||||
|
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_swarm_config(
|
||||||
|
task: str,
|
||||||
|
file_name: str = "swarm_config_output.yaml",
|
||||||
|
model_name: str = "gpt-4o",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generates a swarm configuration based on the provided task and model name.
|
||||||
|
|
||||||
|
This function attempts to generate a swarm configuration by running an agent with the specified task and model name.
|
||||||
|
It then parses the output into YAML format and creates agents based on the parsed YAML content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to be performed by the swarm.
|
||||||
|
file_name (str, optional): The file name for the output YAML configuration. Defaults to "swarm_config_output.yaml".
|
||||||
|
model_name (str, optional): The name of the model to use for the agent. Defaults to "gpt-4o".
|
||||||
|
*args: Additional positional arguments to be passed to the agent's run method.
|
||||||
|
**kwargs: Additional keyword arguments to be passed to the agent's run method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The output of the swarm configuration generation process. This can be a SwarmRouter instance or an error message.
|
||||||
|
"""
|
||||||
|
formatter.print_panel(
|
||||||
|
"Auto Generating Swarm...", "Auto Swarm Builder"
|
||||||
|
)
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(min=4, max=10),
|
||||||
|
)
|
||||||
|
def attempt_generate_swarm_config():
|
||||||
|
try:
|
||||||
|
model = LiteLLM(model_name=model_name)
|
||||||
|
|
||||||
|
# Initialize the agent
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Auto-Swarm-Builder",
|
||||||
|
system_prompt=AUTO_GEN_PROMPT,
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
saved_state_path="swarm_builder.json",
|
||||||
|
user_name="swarms_corp",
|
||||||
|
output_type="str",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate output from the agent
|
||||||
|
raw_output = agent.run(task, *args, **kwargs)
|
||||||
|
yaml_content = parse_yaml_from_swarm_markdown(raw_output)
|
||||||
|
print(yaml_content)
|
||||||
|
|
||||||
|
# Create agents from the YAML file
|
||||||
|
output = create_agents_from_yaml(
|
||||||
|
yaml_string=yaml_content,
|
||||||
|
return_type="run_swarm",
|
||||||
|
)
|
||||||
|
|
||||||
|
formatter.print_panel(
|
||||||
|
"Swarm configuration generated successfully.",
|
||||||
|
"Success",
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
formatter.print_panel(
|
||||||
|
f"Error generating swarm configuration: {str(e)}",
|
||||||
|
"Error",
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return attempt_generate_swarm_config()
|
@ -0,0 +1,264 @@
|
|||||||
|
from typing import Optional, List, Dict, Any, Callable
|
||||||
|
import time
|
||||||
|
from openai import OpenAI
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
import json
|
||||||
|
|
||||||
|
class OpenAIAssistant(Agent):
|
||||||
|
"""
|
||||||
|
OpenAI Assistant wrapper for the swarms framework.
|
||||||
|
Integrates OpenAI's Assistants API with the swarms architecture.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> assistant = OpenAIAssistant(
|
||||||
|
... name="Math Tutor",
|
||||||
|
... instructions="You are a personal math tutor.",
|
||||||
|
... model="gpt-4o",
|
||||||
|
... tools=[{"type": "code_interpreter"}]
|
||||||
|
... )
|
||||||
|
>>> response = assistant.run("Solve 3x + 11 = 14")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
instructions: Optional[str] = None,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
file_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
functions: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Initialize the OpenAI Assistant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the assistant
|
||||||
|
instructions: System instructions for the assistant
|
||||||
|
            model: Model to use (default: gpt-4o)
|
||||||
|
tools: List of tools to enable (code_interpreter, retrieval)
|
||||||
|
file_ids: List of file IDs to attach
|
||||||
|
metadata: Additional metadata
|
||||||
|
functions: List of custom functions to make available
|
||||||
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Initialize tools list with any provided functions
|
||||||
|
self.tools = tools or []
|
||||||
|
if functions:
|
||||||
|
for func in functions:
|
||||||
|
self.tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": func
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create the OpenAI Assistant
|
||||||
|
self.client = OpenAI()
|
||||||
|
self.assistant = self.client.beta.assistants.create(
|
||||||
|
name=name,
|
||||||
|
instructions=instructions,
|
||||||
|
model=model,
|
||||||
|
tools=self.tools,
|
||||||
|
file_ids=file_ids or [],
|
||||||
|
metadata=metadata or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
        # Store available functions
        self.available_functions: Dict[str, Callable] = {}

        # Keep the instructions for later runs and create the conversation thread lazily
        self.instructions = instructions
        self.thread = None
|
||||||
|
|
||||||
|
def add_function(self, func: Callable, description: str, parameters: Dict[str, Any]) -> None:
|
||||||
|
"""Add a function that the assistant can call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to make available to the assistant
|
||||||
|
description: Description of what the function does
|
||||||
|
parameters: JSON schema describing the function parameters
|
||||||
|
"""
|
||||||
|
func_dict = {
|
||||||
|
"name": func.__name__,
|
||||||
|
"description": description,
|
||||||
|
"parameters": parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add to tools list
|
||||||
|
self.tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": func_dict
|
||||||
|
})
|
||||||
|
|
||||||
|
# Store function reference
|
||||||
|
self.available_functions[func.__name__] = func
|
||||||
|
|
||||||
|
# Update assistant with new tools
|
||||||
|
self.assistant = self.client.beta.assistants.update(
|
||||||
|
assistant_id=self.assistant.id,
|
||||||
|
tools=self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
    def _handle_tool_calls(self, run, thread_id: str) -> Any:
|
||||||
|
"""Handle any required tool calls during a run.
|
||||||
|
|
||||||
|
This method processes any tool calls required by the assistant during execution.
|
||||||
|
It extracts function calls, executes them with provided arguments, and submits
|
||||||
|
the results back to the assistant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run: The current run object from the OpenAI API
|
||||||
|
thread_id: ID of the current conversation thread
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated run object after processing tool calls
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If there are errors executing the tool calls
|
||||||
|
"""
|
||||||
|
while run.status == "requires_action":
|
||||||
|
tool_calls = run.required_action.submit_tool_outputs.tool_calls
|
||||||
|
tool_outputs = []
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if tool_call.type == "function":
|
||||||
|
# Get function details
|
||||||
|
function_name = tool_call.function.name
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
|
||||||
|
# Call function if available
|
||||||
|
if function_name in self.available_functions:
|
||||||
|
function_response = self.available_functions[function_name](**function_args)
|
||||||
|
tool_outputs.append({
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"output": str(function_response)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Submit outputs back to the run
|
||||||
|
run = self.client.beta.threads.runs.submit_tool_outputs(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run.id,
|
||||||
|
tool_outputs=tool_outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for processing
|
||||||
|
run = self._wait_for_run(run)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
def _wait_for_run(self, run) -> Any:
|
||||||
|
"""Wait for a run to complete and handle any required actions.
|
||||||
|
|
||||||
|
This method polls the OpenAI API to check the status of a run until it completes
|
||||||
|
        or fails. It handles intermediate states like required actions and polls at a
        fixed interval between checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run: The run object to monitor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The completed run object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the run fails or expires
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
run = self.client.beta.threads.runs.retrieve(
|
||||||
|
thread_id=run.thread_id,
|
||||||
|
run_id=run.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if run.status == "completed":
|
||||||
|
break
|
||||||
|
elif run.status == "requires_action":
|
||||||
|
run = self._handle_tool_calls(run, run.thread_id)
|
||||||
|
if run.status == "completed":
|
||||||
|
break
|
||||||
|
elif run.status in ["failed", "expired"]:
|
||||||
|
raise Exception(f"Run failed with status: {run.status}")
|
||||||
|
|
||||||
|
time.sleep(3) # Wait 3 seconds before checking again
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
def _ensure_thread(self):
|
||||||
|
"""Ensure a thread exists for the conversation.
|
||||||
|
|
||||||
|
This method checks if there is an active thread for the current conversation.
|
||||||
|
If no thread exists, it creates a new one. This maintains conversation context
|
||||||
|
across multiple interactions.
|
||||||
|
|
||||||
|
Side Effects:
|
||||||
|
Sets self.thread if it doesn't exist
|
||||||
|
"""
|
||||||
|
if not self.thread:
|
||||||
|
self.thread = self.client.beta.threads.create()
|
||||||
|
|
||||||
|
def add_message(self, content: str, file_ids: Optional[List[str]] = None) -> None:
|
||||||
|
"""Add a message to the thread.
|
||||||
|
|
||||||
|
This method adds a new user message to the conversation thread. It ensures
|
||||||
|
a thread exists before adding the message and handles file attachments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The text content of the message to add
|
||||||
|
file_ids: Optional list of file IDs to attach to the message. These must be
|
||||||
|
files that have been previously uploaded to OpenAI.
|
||||||
|
|
||||||
|
Side Effects:
|
||||||
|
Creates a new thread if none exists
|
||||||
|
Adds the message to the thread in OpenAI's system
|
||||||
|
"""
|
||||||
|
self._ensure_thread()
|
||||||
|
self.client.beta.threads.messages.create(
|
||||||
|
thread_id=self.thread.id,
|
||||||
|
role="user",
|
||||||
|
content=content,
|
||||||
|
file_ids=file_ids or []
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_response(self) -> str:
|
||||||
|
"""Get the latest assistant response from the thread."""
|
||||||
|
messages = self.client.beta.threads.messages.list(
|
||||||
|
thread_id=self.thread.id,
|
||||||
|
order="desc",
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages.data:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
message = messages.data[0]
|
||||||
|
if message.role == "assistant":
|
||||||
|
return message.content[0].text.value
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""Run a task using the OpenAI Assistant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task or prompt to send to the assistant
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The assistant's response as a string
|
||||||
|
"""
|
||||||
|
self._ensure_thread()
|
||||||
|
|
||||||
|
# Add the user message
|
||||||
|
self.add_message(task)
|
||||||
|
|
||||||
|
# Create and run the assistant
|
||||||
|
run = self.client.beta.threads.runs.create(
|
||||||
|
thread_id=self.thread.id,
|
||||||
|
assistant_id=self.assistant.id,
|
||||||
|
instructions=self.instructions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
run = self._wait_for_run(run)
|
||||||
|
|
||||||
|
# Only get and return the response if run completed successfully
|
||||||
|
if run.status == "completed":
|
||||||
|
return self._get_response()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def call(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""Alias for run() to maintain compatibility with different agent interfaces."""
|
||||||
|
return self.run(task, *args, **kwargs)
|
@ -1,9 +1,5 @@
-from swarms.artifacts.base_artifact import BaseArtifact
-from swarms.artifacts.text_artifact import TextArtifact
 from swarms.artifacts.main_artifact import Artifact

 __all__ = [
-    "BaseArtifact",
-    "TextArtifact",
     "Artifact",
 ]
|
@ -1,77 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseArtifact(ABC):
|
|
||||||
"""
|
|
||||||
Base class for artifacts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.id is None:
|
|
||||||
self.id = uuid.uuid4().hex
|
|
||||||
if self.name is None:
|
|
||||||
self.name = self.id
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_to_bytes(cls, value: Any) -> bytes:
|
|
||||||
"""
|
|
||||||
Convert the value to bytes.
|
|
||||||
"""
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
return str(value).encode()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_to_dict(cls, value: Any) -> dict:
|
|
||||||
"""
|
|
||||||
Convert the value to a dictionary.
|
|
||||||
"""
|
|
||||||
if isinstance(value, dict):
|
|
||||||
dict_value = value
|
|
||||||
else:
|
|
||||||
dict_value = json.loads(value)
|
|
||||||
|
|
||||||
return {k: v for k, v in dict_value.items()}
|
|
||||||
|
|
||||||
def to_text(self) -> str:
|
|
||||||
"""
|
|
||||||
Convert the value to text.
|
|
||||||
"""
|
|
||||||
return str(self.value)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""
|
|
||||||
Return a string representation of the artifact.
|
|
||||||
"""
|
|
||||||
return self.to_text()
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
"""
|
|
||||||
Return the boolean value of the artifact.
|
|
||||||
"""
|
|
||||||
return bool(self.value)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""
|
|
||||||
Return the length of the artifact.
|
|
||||||
"""
|
|
||||||
return len(self.value)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __add__(self, other: BaseArtifact) -> BaseArtifact:
|
|
||||||
"""
|
|
||||||
Add two artifacts together.
|
|
||||||
"""
|
|
||||||
...
|
|
@ -1,58 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable
|
|
||||||
from swarms.artifacts.base_artifact import BaseArtifact
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextArtifact(BaseArtifact):
|
|
||||||
"""
|
|
||||||
Represents a text artifact.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
value (str): The text value of the artifact.
|
|
||||||
encoding (str, optional): The encoding of the text (default is "utf-8").
|
|
||||||
encoding_error_handler (str, optional): The error handler for encoding errors (default is "strict").
|
|
||||||
_embedding (list[float]): The embedding of the text artifact (default is an empty list).
|
|
||||||
|
|
||||||
Properties:
|
|
||||||
embedding (Optional[list[float]]): The embedding of the text artifact.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
__add__(self, other: BaseArtifact) -> TextArtifact: Concatenates the text value of the artifact with another artifact.
|
|
||||||
__bool__(self) -> bool: Checks if the text value of the artifact is non-empty.
|
|
||||||
generate_embedding(self, driver: BaseEmbeddingModel) -> Optional[list[float]]: Generates the embedding of the text artifact using a given embedding model.
|
|
||||||
token_count(self, tokenizer: BaseTokenizer) -> int: Counts the number of tokens in the text artifact using a given tokenizer.
|
|
||||||
to_bytes(self) -> bytes: Converts the text value of the artifact to bytes using the specified encoding and error handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
value: str
|
|
||||||
encoding: str = "utf-8"
|
|
||||||
encoding_error_handler: str = "strict"
|
|
||||||
tokenizer: Callable = None
|
|
||||||
_embedding: list[float] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding(self) -> list[float] | None:
|
|
||||||
return None if len(self._embedding) == 0 else self._embedding
|
|
||||||
|
|
||||||
def __add__(self, other: BaseArtifact) -> TextArtifact:
|
|
||||||
return TextArtifact(self.value + other.value)
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return bool(self.value.strip())
|
|
||||||
|
|
||||||
def generate_embedding(self, model) -> list[float] | None:
|
|
||||||
self._embedding.clear()
|
|
||||||
self._embedding.extend(model.embed_string(str(self.value)))
|
|
||||||
|
|
||||||
return self.embedding
|
|
||||||
|
|
||||||
def token_count(self) -> int:
|
|
||||||
return self.tokenizer.count_tokens(str(self.value))
|
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
|
||||||
return self.value.encode(
|
|
||||||
encoding=self.encoding, errors=self.encoding_error_handler
|
|
||||||
)
|
|
@ -0,0 +1,62 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
from swarms.structs.base_workflow import BaseWorkflow
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.utils.loguru_logger import logger
|
||||||
|
|
||||||
|
class AsyncWorkflow(BaseWorkflow):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "AsyncWorkflow",
|
||||||
|
agents: List[Agent] = None,
|
||||||
|
max_workers: int = 5,
|
||||||
|
dashboard: bool = False,
|
||||||
|
autosave: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(agents=agents, **kwargs)
|
||||||
|
self.name = name
|
||||||
|
self.agents = agents or []
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.dashboard = dashboard
|
||||||
|
self.autosave = autosave
|
||||||
|
self.verbose = verbose
|
||||||
|
self.task_pool = []
|
||||||
|
self.results = []
|
||||||
|
self.loop = None
|
||||||
|
|
||||||
|
async def _execute_agent_task(self, agent: Agent, task: str) -> Any:
|
||||||
|
"""Execute a single agent task asynchronously"""
|
||||||
|
try:
|
||||||
|
if self.verbose:
|
||||||
|
logger.info(f"Agent {agent.agent_name} processing task: {task}")
|
||||||
|
result = await agent.arun(task)
|
||||||
|
if self.verbose:
|
||||||
|
logger.info(f"Agent {agent.agent_name} completed task")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in agent {agent.agent_name}: {str(e)}")
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
async def run(self, task: str) -> List[Any]:
|
||||||
|
"""Run the workflow with all agents processing the task concurrently"""
|
||||||
|
if not self.agents:
|
||||||
|
raise ValueError("No agents provided to the workflow")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tasks for all agents
|
||||||
|
tasks = [self._execute_agent_task(agent, task) for agent in self.agents]
|
||||||
|
|
||||||
|
# Execute all tasks concurrently
|
||||||
|
self.results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
if self.autosave:
|
||||||
|
# TODO: Implement autosave logic here
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self.results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in workflow execution: {str(e)}")
|
||||||
|
raise
|
@ -1,5 +1,3 @@
-from loguru import logger
-
 import os
 from typing import List
@ -0,0 +1,665 @@
|
|||||||
|
"""
|
||||||
|
GraphSwarm: A production-grade framework for orchestrating swarms of agents
|
||||||
|
Author: Claude
|
||||||
|
License: MIT
|
||||||
|
Version: 2.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import networkx as nx
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from swarms import Agent
|
||||||
|
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logger.add(
|
||||||
|
"graphswarm.log",
|
||||||
|
rotation="500 MB",
|
||||||
|
retention="10 days",
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutput(BaseModel):
|
||||||
|
"""Structured output from an agent."""
|
||||||
|
|
||||||
|
agent_name: str
|
||||||
|
timestamp: float = Field(default_factory=time.time)
|
||||||
|
output: Any
|
||||||
|
execution_time: float
|
||||||
|
error: Optional[str] = None
|
||||||
|
metadata: Dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmOutput(BaseModel):
|
||||||
|
"""Structured output from the entire swarm."""
|
||||||
|
|
||||||
|
timestamp: float = Field(default_factory=time.time)
|
||||||
|
outputs: Dict[str, AgentOutput]
|
||||||
|
execution_time: float
|
||||||
|
success: bool
|
||||||
|
error: Optional[str] = None
|
||||||
|
metadata: Dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmMemory:
    """Vector-based memory system for GraphSwarm using ChromaDB."""

    def __init__(self, collection_name: str = "swarm_memories"):
        """Initialize SwarmMemory with ChromaDB."""
        self.client = chromadb.Client()

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "GraphSwarm execution memories"},
        )

    def store_execution(self, task: str, result: SwarmOutput):
        """Store execution results in vector memory."""
        try:
            # Create metadata
            metadata = {
                "timestamp": datetime.now().isoformat(),
                "success": result.success,
                "execution_time": result.execution_time,
                "agent_sequence": json.dumps([name for name in result.outputs.keys()]),
                "error": result.error if result.error else "",
            }

            # Create document from outputs
            document = {
                "task": task,
                "outputs": json.dumps(
                    {
                        name: {
                            "output": str(output.output),
                            "execution_time": output.execution_time,
                            "error": output.error,
                        }
                        for name, output in result.outputs.items()
                    }
                ),
            }

            # Store in ChromaDB
            self.collection.add(
                documents=[json.dumps(document)],
                metadatas=[metadata],
                ids=[f"exec_{datetime.now().timestamp()}"],
            )

            print("added to database")

            logger.info(f"Stored execution in memory: {task}")

        except Exception as e:
            logger.error(f"Failed to store execution in memory: {str(e)}")

    def get_similar_executions(self, task: str, limit: int = 5):
        """Retrieve similar past executions."""
        try:
            # Query ChromaDB for similar executions
            results = self.collection.query(
                query_texts=[task],
                n_results=limit,
                include=["documents", "metadatas"],
            )

            print(results)

            if not results["documents"]:
                return []

            # Process results
            executions = []
            for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
                doc_dict = json.loads(doc)
                executions.append(
                    {
                        "task": doc_dict["task"],
                        "outputs": json.loads(doc_dict["outputs"]),
                        "success": metadata["success"],
                        "execution_time": metadata["execution_time"],
                        "agent_sequence": json.loads(metadata["agent_sequence"]),
                        "timestamp": metadata["timestamp"],
                    }
                )

            return executions

        except Exception as e:
            logger.error(f"Failed to retrieve similar executions: {str(e)}")
            return []

    def get_optimal_sequence(self, task: str) -> Optional[List[str]]:
        """Get the most successful agent sequence for similar tasks."""
        similar_executions = self.get_similar_executions(task)
        print(f"similar_executions {similar_executions}")

        if not similar_executions:
            return None

        # Sort by success and execution time
        successful_execs = [ex for ex in similar_executions if ex["success"]]

        if not successful_execs:
            return None

        # Return sequence from most successful execution
        return successful_execs[0]["agent_sequence"]

    def clear_memory(self):
        """Clear all memories."""
        self.client.delete_collection(self.collection.name)
        self.collection = self.client.get_or_create_collection(name=self.collection.name)


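# Example (hedged sketch, not part of the original file): how SwarmMemory could
# be exercised on its own. Assumes chromadb is installed and the AgentOutput /
# SwarmOutput models defined above; the task text and values are illustrative.
#
# memory = SwarmMemory(collection_name="demo_memories")
# demo_result = SwarmOutput(
#     outputs={
#         "Market-Data-Collector": AgentOutput(
#             agent_name="Market-Data-Collector",
#             output="Collected 30 days of price history.",
#             execution_time=1.2,
#         )
#     },
#     execution_time=1.3,
#     success=True,
# )
# memory.store_execution("Collect market data", demo_result)
# print(memory.get_similar_executions("Collect market data", limit=1))
# print(memory.get_optimal_sequence("Collect market data"))

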
class GraphSwarm:
    """
    Enhanced framework for creating and managing swarms of collaborative agents.
    """

    def __init__(
        self,
        agents: Union[List[Agent], List[Tuple[Agent, List[str]]], None] = None,
        max_workers: Optional[int] = None,
        swarm_name: str = "Collaborative Agent Swarm",
        memory_collection: str = "swarm_memory",
    ):
        """Initialize GraphSwarm."""
        self.graph = nx.DiGraph()
        self.agents: Dict[str, Agent] = {}
        self.dependencies: Dict[str, List[str]] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.swarm_name = swarm_name
        self.memory_collection = memory_collection
        self.memory = SwarmMemory(collection_name=memory_collection)

        if agents:
            self.initialize_agents(agents)

        logger.info(f"Initialized GraphSwarm: {swarm_name}")

    def initialize_agents(
        self, agents: Union[List[Agent], List[Tuple[Agent, List[str]]]]
    ):
        """Initialize agents and their dependencies."""
        try:
            # Handle list of Agents or (Agent, dependencies) tuples
            for item in agents:
                if isinstance(item, tuple):
                    agent, dependencies = item
                else:
                    agent, dependencies = item, []

                if not isinstance(agent, Agent):
                    raise ValueError(f"Expected Agent object, got {type(agent)}")

                self.agents[agent.agent_name] = agent
                self.dependencies[agent.agent_name] = dependencies
                self.graph.add_node(agent.agent_name, agent=agent)

                # Add dependencies
                for dep in dependencies:
                    if dep not in self.agents:
                        raise ValueError(
                            f"Dependency {dep} not found for agent {agent.agent_name}"
                        )
                    self.graph.add_edge(dep, agent.agent_name)

            self._validate_graph()

        except Exception as e:
            logger.error(f"Failed to initialize agents: {str(e)}")
            raise

    def _validate_graph(self):
        """Validate the agent dependency graph."""
        if not self.graph.nodes():
            raise ValueError("No agents added to swarm")

        if not nx.is_directed_acyclic_graph(self.graph):
            cycles = list(nx.simple_cycles(self.graph))
            raise ValueError(f"Agent dependency graph contains cycles: {cycles}")

    def _get_agent_role_description(self, agent_name: str) -> str:
        """Generate a description of the agent's role in the swarm."""
        predecessors = list(self.graph.predecessors(agent_name))
        successors = list(self.graph.successors(agent_name))
        position = (
            "initial"
            if not predecessors
            else ("final" if not successors else "intermediate")
        )

        role = f"""You are {agent_name}, a specialized agent in the {self.swarm_name}.
Position: {position} agent in the workflow

Your relationships:"""

        if predecessors:
            role += f"\nYou receive input from: {', '.join(predecessors)}"
        if successors:
            role += f"\nYour output will be used by: {', '.join(successors)}"

        return role

    def _generate_workflow_context(self) -> str:
        """Generate a description of the entire workflow."""
        execution_order = list(nx.topological_sort(self.graph))

        workflow = f"""Workflow Overview of {self.swarm_name}:

Processing Order:
{' -> '.join(execution_order)}

Agent Roles:
"""

        for agent_name in execution_order:
            predecessors = list(self.graph.predecessors(agent_name))
            successors = list(self.graph.successors(agent_name))

            workflow += f"\n\n{agent_name}:"
            if predecessors:
                workflow += f"\n- Receives from: {', '.join(predecessors)}"
            if successors:
                workflow += f"\n- Sends to: {', '.join(successors)}"
            if not predecessors and not successors:
                workflow += "\n- Independent agent"

        return workflow

    def _build_agent_prompt(
        self, agent_name: str, task: str, context: Dict = None
    ) -> str:
        """Build a comprehensive prompt for the agent including role and context."""
        prompt_parts = [
            self._get_agent_role_description(agent_name),
            "\nWorkflow Context:",
            self._generate_workflow_context(),
            "\nYour Task:",
            task,
        ]

        if context:
            prompt_parts.extend(["\nContext from Previous Agents:", str(context)])

        prompt_parts.extend(
            [
                "\nInstructions:",
                "1. Process the task according to your role",
                "2. Consider the input from previous agents when available",
                "3. Provide clear, structured output",
                "4. Remember that your output will be used by subsequent agents",
                "\nResponse Guidelines:",
                "- Provide clear, well-organized output",
                "- Include relevant details and insights",
                "- Highlight key findings",
                "- Flag any uncertainties or issues",
            ]
        )

        return "\n".join(prompt_parts)

    async def _execute_agent(
        self, agent_name: str, task: str, context: Dict = None
    ) -> AgentOutput:
        """Execute a single agent."""
        start_time = time.time()
        agent = self.agents[agent_name]

        try:
            # Build comprehensive prompt
            full_prompt = self._build_agent_prompt(agent_name, task, context)
            logger.debug(f"Prompt for {agent_name}:\n{full_prompt}")

            # Execute agent
            output = await asyncio.to_thread(agent.run, full_prompt)

            return AgentOutput(
                agent_name=agent_name,
                output=output,
                execution_time=time.time() - start_time,
                metadata={
                    "task": task,
                    "context": context,
                    "position_in_workflow": list(
                        nx.topological_sort(self.graph)
                    ).index(agent_name),
                },
            )

        except Exception as e:
            logger.error(f"Error executing agent {agent_name}: {str(e)}")
            return AgentOutput(
                agent_name=agent_name,
                output=None,
                execution_time=time.time() - start_time,
                error=str(e),
                metadata={"task": task},
            )

    async def execute(self, task: str) -> SwarmOutput:
        """
        Execute the entire swarm of agents with memory integration.

        Args:
            task: Initial task to execute

        Returns:
            SwarmOutput: Structured output from all agents
        """
        start_time = time.time()
        outputs = {}
        success = True
        error = None

        try:
            # Get similar past executions
            similar_executions = self.memory.get_similar_executions(task, limit=3)
            optimal_sequence = self.memory.get_optimal_sequence(task)

            # Get base execution order
            base_execution_order = list(nx.topological_sort(self.graph))

            # Determine final execution order
            if optimal_sequence and all(
                agent in base_execution_order for agent in optimal_sequence
            ):
                logger.info(f"Using optimal sequence from memory: {optimal_sequence}")
                execution_order = optimal_sequence
            else:
                execution_order = base_execution_order

            # Get historical context if available
            historical_context = {}
            if similar_executions:
                best_execution = similar_executions[0]
                if best_execution["success"]:
                    historical_context = {
                        "similar_task": best_execution["task"],
                        "previous_outputs": best_execution["outputs"],
                        "execution_time": best_execution["execution_time"],
                        "success_patterns": self._extract_success_patterns(
                            similar_executions
                        ),
                    }

            # Execute agents in order
            for agent_name in execution_order:
                try:
                    # Get context from dependencies and history
                    agent_context = {
                        "dependencies": {
                            dep: outputs[dep].output
                            for dep in self.graph.predecessors(agent_name)
                            if dep in outputs
                        },
                        "historical": historical_context,
                        "position": execution_order.index(agent_name),
                        "total_agents": len(execution_order),
                    }

                    # Execute agent with enhanced context
                    output = await self._execute_agent(agent_name, task, agent_context)
                    outputs[agent_name] = output

                    # Update historical context with current execution
                    if output.output:
                        historical_context.update(
                            {f"current_{agent_name}_output": output.output}
                        )

                    # Check for errors
                    if output.error:
                        success = False
                        error = f"Agent {agent_name} failed: {output.error}"

                        # Try to recover using memory
                        if similar_executions:
                            recovery_output = self._attempt_recovery(
                                agent_name, task, similar_executions
                            )
                            if recovery_output:
                                outputs[agent_name] = recovery_output
                                success = True
                                error = None
                                continue
                        break

                except Exception as agent_error:
                    logger.error(
                        f"Error executing agent {agent_name}: {str(agent_error)}"
                    )
                    success = False
                    error = f"Agent {agent_name} failed: {str(agent_error)}"
                    break

            # Create result
            result = SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=success,
                error=error,
                metadata={
                    "task": task,
                    "used_optimal_sequence": optimal_sequence is not None,
                    "similar_executions_found": len(similar_executions),
                    "execution_order": execution_order,
                    "historical_context_used": bool(historical_context),
                },
            )

            # Store execution in memory
            await self._store_execution_async(task, result)

            return result

        except Exception as e:
            logger.error(f"Swarm execution failed: {str(e)}")
            return SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=False,
                error=str(e),
                metadata={"task": task},
            )

    def run(self, task: str) -> SwarmOutput:
        """Synchronous interface to execute the swarm."""
        return asyncio.run(self.execute(task))

    def _extract_success_patterns(self, similar_executions: List[Dict]) -> Dict:
        """Extract success patterns from similar executions."""
        patterns = {}
        successful_execs = [ex for ex in similar_executions if ex["success"]]

        if successful_execs:
            patterns = {
                "common_sequences": self._find_common_sequences(successful_execs),
                "avg_execution_time": sum(
                    ex["execution_time"] for ex in successful_execs
                )
                / len(successful_execs),
                "successful_strategies": self._extract_strategies(successful_execs),
            }

        return patterns

    def _attempt_recovery(
        self,
        failed_agent: str,
        task: str,
        similar_executions: List[Dict],
    ) -> Optional[AgentOutput]:
        """Attempt to recover from failure using memory."""
        for execution in similar_executions:
            if execution["success"] and failed_agent in execution["outputs"]:
                historical_output = execution["outputs"][failed_agent]

                return AgentOutput(
                    agent_name=failed_agent,
                    output=historical_output["output"],
                    execution_time=historical_output["execution_time"],
                    metadata={
                        "recovered_from_memory": True,
                        "original_task": execution["task"],
                    },
                )
        return None

    async def _store_execution_async(self, task: str, result: SwarmOutput):
        """Asynchronously store execution in memory."""
        try:
            await asyncio.to_thread(self.memory.store_execution, task, result)
        except Exception as e:
            logger.error(f"Failed to store execution in memory: {str(e)}")

    def add_agent(self, agent: Agent, dependencies: List[str] = None):
        """Add a new agent to the swarm."""
        dependencies = dependencies or []
        self.agents[agent.agent_name] = agent
        self.dependencies[agent.agent_name] = dependencies
        self.graph.add_node(agent.agent_name, agent=agent)

        for dep in dependencies:
            if dep not in self.agents:
                raise ValueError(f"Dependency {dep} not found")
            self.graph.add_edge(dep, agent.agent_name)

        self._validate_graph()


if __name__ == "__main__":
    try:
        # Create agents
        data_collector = Agent(
            agent_name="Market-Data-Collector",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        trend_analyzer = Agent(
            agent_name="Market-Trend-Analyzer",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        report_generator = Agent(
            agent_name="Investment-Report-Generator",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        # Create swarm
        swarm = GraphSwarm(
            agents=[
                (data_collector, []),
                (trend_analyzer, ["Market-Data-Collector"]),
                (report_generator, ["Market-Trend-Analyzer"]),
            ],
            swarm_name="Market Analysis Intelligence Network",
        )

        # Run the swarm
        result = swarm.run(
            "Analyze current market trends for tech stocks and provide investment recommendations"
        )

        # Print results
        print(f"Execution success: {result.success}")
        print(f"Total time: {result.execution_time:.2f} seconds")

        for agent_name, output in result.outputs.items():
            print(f"\nAgent: {agent_name}")
            print(f"Output: {output.output}")
            if output.error:
                print(f"Error: {output.error}")

    except Exception as error:
        logger.error(error)
        raise error

@ -0,0 +1,276 @@
import asyncio
import pulsar

from pulsar import ConsumerType
from loguru import logger
from swarms import Agent
from typing import List, Dict, Any
import json


class ScalableAsyncAgentSwarm:
    """
    A scalable, asynchronous swarm of agents leveraging Apache Pulsar for inter-agent communication.
    Provides load balancing, health monitoring, dead letter queues, and centralized logging.
    """

    def __init__(
        self,
        pulsar_url: str,
        topic: str,
        dlq_topic: str,
        agents_config: List[Dict[str, Any]],
    ):
        """
        Initializes the async swarm with agents.

        Args:
            pulsar_url (str): The URL of the Apache Pulsar broker.
            topic (str): The main topic for task distribution.
            dlq_topic (str): The Dead Letter Queue topic for failed messages.
            agents_config (List[Dict[str, Any]]): List of agent configurations with `name`, `description`, and `model_name`.
        """
        self.pulsar_url = pulsar_url
        self.topic = topic
        self.dlq_topic = dlq_topic
        self.agents_config = agents_config
        self.client = pulsar.Client(pulsar_url)
        self.consumer = self.client.subscribe(
            topic,
            subscription_name="swarm-task-sub",
            consumer_type=ConsumerType.Shared,
        )
        self.dlq_producer = self.client.create_producer(dlq_topic)
        self.response_logger = []
        self.agents = [self.create_agent(config) for config in agents_config]
        self.agent_index = 0

        logger.info(
            "Swarm initialized with agents: {}",
            [agent["name"] for agent in agents_config],
        )

    def create_agent(self, agent_config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates a new agent configuration with asynchronous capabilities.

        Args:
            agent_config (Dict[str, Any]): Configuration dictionary with agent details.

        Returns:
            Dict[str, Any]: A dictionary containing agent metadata and functionality.
        """
        agent_name = agent_config["name"]
        description = agent_config["description"]
        model_name = agent_config.get("model_name", "gpt-4o-mini")

        class AsyncAgent:
            """
            An asynchronous agent that processes tasks and communicates via Apache Pulsar.
            """

            def __init__(self, name: str, description: str, model_name: str):
                self.name = name
                self.description = description
                self.agent = Agent(
                    agent_name=name,
                    model_name=model_name,
                    max_loops="auto",
                    interactive=True,
                    streaming_on=True,
                )
                logger.info(f"Initialized agent '{name}' - {description}")

            async def process_task(self, message: str) -> Dict[str, Any]:
                """
                Processes a single task using the agent.

                Args:
                    message (str): The task message.

                Returns:
                    Dict[str, Any]: JSON-formatted response.
                """
                try:
                    logger.info(f"Agent {self.name} processing task: {message}")
                    response = await asyncio.to_thread(self.agent.run, message)
                    logger.info(f"Agent {self.name} completed task.")
                    return {
                        "agent_name": self.name,
                        "response": response,
                    }
                except Exception as e:
                    logger.error(f"Agent {self.name} encountered an error: {e}")
                    return {"agent_name": self.name, "error": str(e)}

        return {
            "name": agent_name,
            "instance": AsyncAgent(agent_name, description, model_name),
        }

    async def distribute_task(self, message: str):
        """
        Distributes a task to the next available agent using round-robin.

        Args:
            message (str): The task message.
        """
        agent = self.agents[self.agent_index]
        self.agent_index = (self.agent_index + 1) % len(self.agents)

        try:
            response = await agent["instance"].process_task(message)
            self.log_response(response)
        except Exception as e:
            logger.error(f"Error processing task by agent {agent['name']}: {e}")
            self.send_to_dlq(message)

    async def monitor_health(self):
        """
        Periodically monitors the health of agents.
        """
        while True:
            logger.info("Performing health check for all agents.")
            for agent in self.agents:
                logger.info(f"Agent {agent['name']} is online.")
            await asyncio.sleep(10)

    def send_to_dlq(self, message: str):
        """
        Sends a failed message to the Dead Letter Queue (DLQ).

        Args:
            message (str): The message to send to the DLQ.
        """
        try:
            self.dlq_producer.send(message.encode("utf-8"))
            logger.info("Message sent to Dead Letter Queue.")
        except Exception as e:
            logger.error(f"Failed to send message to DLQ: {e}")

    def log_response(self, response: Dict[str, Any]):
        """
        Logs the response to a centralized list for later analysis.

        Args:
            response (Dict[str, Any]): The agent's response.
        """
        self.response_logger.append(response)
        logger.info(f"Response logged: {response}")

    async def listen_and_distribute(self):
        """
        Listens to the main Pulsar topic and distributes tasks to agents.
        """
        while True:
            msg = self.consumer.receive()
            try:
                message = msg.data().decode("utf-8")
                logger.info(f"Received task: {message}")
                await self.distribute_task(message)
                self.consumer.acknowledge(msg)
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                self.send_to_dlq(msg.data().decode("utf-8"))
                self.consumer.negative_acknowledge(msg)

    async def run(self):
        """
        Runs the swarm asynchronously with health monitoring and task distribution.
        """
        logger.info("Starting the async swarm...")
        task_listener = asyncio.create_task(self.listen_and_distribute())
        health_monitor = asyncio.create_task(self.monitor_health())
        await asyncio.gather(task_listener, health_monitor)

    def shutdown(self):
        """
        Safely shuts down the swarm and logs all responses.
        """
        logger.info("Shutting down the swarm...")
        self.client.close()
        with open("responses.json", "w") as f:
            json.dump(self.response_logger, f, indent=4)
        logger.info("Responses saved to 'responses.json'.")


# from scalable_agent_swarm import ScalableAsyncAgentSwarm # Assuming your swarm class is saved here

if __name__ == "__main__":
    # Example Configuration
    PULSAR_URL = "pulsar://localhost:6650"
    TOPIC = "stock-analysis"
    DLQ_TOPIC = "stock-analysis-dlq"

    # Agents configuration
    AGENTS_CONFIG = [
        {
            "name": "Stock-Analysis-Agent-1",
            "description": "Analyzes stock trends.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Stock-News-Agent",
            "description": "Summarizes stock news.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Tech-Trends-Agent",
            "description": "Tracks tech sector trends.",
            "model_name": "gpt-4o-mini",
        },
    ]

    # Tasks to send
    TASKS = [
        "Analyze the trend for tech stocks in Q4 2024",
        "Summarize the latest news on the S&P 500",
        "Identify the top-performing sectors in the stock market",
        "Provide a forecast for AI-related stocks for 2025",
    ]

    # Initialize the swarm
    swarm = ScalableAsyncAgentSwarm(PULSAR_URL, TOPIC, DLQ_TOPIC, AGENTS_CONFIG)
    try:
        # Send the tasks to the topic first; Pulsar retains them until the
        # shared subscription consumes them.
        client = pulsar.Client(PULSAR_URL)
        producer = client.create_producer(TOPIC)

        for task in TASKS:
            producer.send(task.encode("utf-8"))
            print(f"Sent task: {task}")

        producer.close()
        client.close()

        # Run the swarm. asyncio.run creates the event loop that the listener
        # and health-monitor tasks require; create_task cannot be called
        # before a loop is running.
        asyncio.run(swarm.run())
    except KeyboardInterrupt:
        swarm.shutdown()

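# Example (hedged sketch, not part of the original file): inspecting messages
# that landed on the Dead Letter Queue after repeated failures. Assumes the
# PULSAR_URL and DLQ_TOPIC values defined above and a reachable broker.
#
# dlq_client = pulsar.Client(PULSAR_URL)
# dlq_consumer = dlq_client.subscribe(DLQ_TOPIC, subscription_name="dlq-inspector")
# msg = dlq_consumer.receive(timeout_millis=5000)  # raises if the DLQ is empty
# print("Failed task:", msg.data().decode("utf-8"))
# dlq_consumer.acknowledge(msg)
# dlq_client.close()
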
@ -1,203 +0,0 @@
from typing import Any
import inspect
from functools import partial
import logging


class NameResolver:
    """Utility class for resolving names of various objects"""

    @staticmethod
    def get_name(obj: Any, default: str = "unnamed_callable") -> str:
        """
        Get the name of any object with multiple fallback strategies.

        Args:
            obj: The object to get the name from
            default: Default name if all strategies fail

        Returns:
            str: The resolved name
        """
        strategies = [
            # Try getting __name__ attribute
            lambda x: getattr(x, "__name__", None),
            # Try getting class name
            lambda x: (x.__class__.__name__ if hasattr(x, "__class__") else None),
            # Try getting function name if it's a partial
            lambda x: (x.func.__name__ if isinstance(x, partial) else None),
            # Try getting the name from the class's type
            lambda x: type(x).__name__,
            # Try getting qualname
            lambda x: getattr(x, "__qualname__", None),
            # Try getting the module and class name
            lambda x: (
                f"{x.__module__}.{x.__class__.__name__}"
                if hasattr(x, "__module__")
                else None
            ),
            # For async functions
            lambda x: (x.__name__ if inspect.iscoroutinefunction(x) else None),
            # For classes with custom __str__
            lambda x: (
                str(x)
                if hasattr(x, "__str__") and x.__str__ != object.__str__
                else None
            ),
            # For wrapped functions
            lambda x: (
                getattr(x, "__wrapped__", None).__name__
                if hasattr(x, "__wrapped__")
                else None
            ),
        ]

        # Try each strategy
        for strategy in strategies:
            try:
                name = strategy(obj)
                if name and isinstance(name, str):
                    return name.replace(" ", "_").replace("-", "_")
            except Exception:
                continue

        # Return default if all strategies fail
        return default

    @staticmethod
    def get_callable_details(obj: Any) -> dict:
        """
        Get detailed information about a callable object.

        Returns:
            dict: Dictionary containing:
                - name: The resolved name
                - type: The type of callable
                - signature: The signature if available
                - module: The module name if available
                - doc: The docstring if available
        """
        details = {
            "name": NameResolver.get_name(obj),
            "type": "unknown",
            "signature": None,
            "module": getattr(obj, "__module__", "unknown"),
            "doc": inspect.getdoc(obj) or "No documentation available",
        }

        # Determine the type
        if inspect.isclass(obj):
            details["type"] = "class"
        elif inspect.iscoroutinefunction(obj):
            details["type"] = "async_function"
        elif inspect.isfunction(obj):
            details["type"] = "function"
        elif isinstance(obj, partial):
            details["type"] = "partial"
        elif callable(obj):
            details["type"] = "callable"

        # Try to get signature
        try:
            details["signature"] = str(inspect.signature(obj))
        except (ValueError, TypeError):
            details["signature"] = "Unknown signature"

        return details

    @classmethod
    def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str:
        """
        Safely get a name with retries and validation.

        Args:
            obj: Object to get name from
            max_retries: Maximum number of retry attempts

        Returns:
            str: A valid name string
        """
        retries = 0
        last_error = None

        while retries < max_retries:
            try:
                name = cls.get_name(obj)

                # Validate and clean the name
                if name:
                    # Remove invalid characters
                    clean_name = "".join(
                        c for c in name if c.isalnum() or c in ["_", "."]
                    )

                    # Ensure it starts with a letter or underscore
                    if not clean_name[0].isalpha() and clean_name[0] != "_":
                        clean_name = f"_{clean_name}"

                    return clean_name

            except Exception as e:
                last_error = e
                retries += 1

        # If all retries failed, generate a unique fallback name
        import uuid

        fallback = f"callable_{uuid.uuid4().hex[:8]}"
        logging.warning(
            f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. "
            f"Last error: {str(last_error)}"
        )
        return fallback


# # Example usage
# if __name__ == "__main__":
#     def test_resolver():
#         # Test cases
#         class TestClass:
#             def method(self):
#                 pass

#         async def async_func():
#             pass

#         test_cases = [
#             TestClass,  # Class
#             TestClass(),  # Instance
#             async_func,  # Async function
#             lambda x: x,  # Lambda
#             partial(print, end=""),  # Partial
#             TestClass.method,  # Method
#             print,  # Built-in function
#             str,  # Built-in class
#         ]

#         resolver = NameResolver()

#         print("\nName Resolution Results:")
#         print("-" * 50)
#         for obj in test_cases:
#             details = resolver.get_callable_details(obj)
#             safe_name = resolver.get_safe_name(obj)
#             print(f"\nObject: {obj}")
#             print(f"Safe Name: {safe_name}")
#             print(f"Details: {details}")

#     test_resolver()

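# Quick illustration (not from the original file) of what the resolver returns
# for a few common callables:
#
# print(NameResolver.get_name(print))                    # -> "print"
# print(NameResolver.get_safe_name(lambda x: x))         # -> "lambda" (invalid chars stripped)
# print(NameResolver.get_callable_details(str)["type"])  # -> "class"
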
@ -0,0 +1,105 @@
try:
    import litellm
    from litellm import completion
except ImportError:
    import subprocess

    subprocess.check_call(["pip", "install", "litellm"])
    import litellm
    from litellm import completion

litellm.set_verbose = True


class LiteLLM:
    """
    This class represents a LiteLLM.
    It is used to interact with the LLM model for various tasks.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        system_prompt: str = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
    ):
        """
        Initialize the LiteLLM with the given parameters.

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool, optional): Whether to stream the output. Defaults to False.
            temperature (float, optional): The temperature for the model. Defaults to 0.5.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _prepare_messages(self, task: str) -> list:
        """
        Prepare the messages for the given task.

        Args:
            task (str): The task to prepare messages for.

        Returns:
            list: A list of messages prepared for the task.
        """
        messages = []

        if self.system_prompt:  # Check if system_prompt is not None
            messages.append({"role": "system", "content": self.system_prompt})

        messages.append({"role": "user", "content": task})

        return messages

    def run(self, task: str, *args, **kwargs):
        """
        Run the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        messages = self._prepare_messages(task)

        response = completion(
            model=self.model_name,
            messages=messages,
            stream=self.stream,
            temperature=self.temperature,
            # max_completion_tokens=self.max_tokens,
            max_tokens=self.max_tokens,
            *args,
            **kwargs,
        )
        # Accessing the content (assumes a non-streaming response)
        content = response.choices[0].message.content
        return content

    def __call__(self, task: str, *args, **kwargs):
        """
        Call the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        return self.run(task, *args, **kwargs)

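# Example (illustrative, not part of the original file): the wrapper can be
# called directly once an API key for the chosen provider is set in the
# environment, e.g. OPENAI_API_KEY for the OpenAI models.
#
# llm = LiteLLM(model_name="gpt-4o-mini", temperature=0.2, max_tokens=500)
# print(llm.run("Summarize the benefits of multi-agent systems in two sentences."))
# print(llm("And now in one sentence."))  # __call__ delegates to run()
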
@ -0,0 +1,73 @@
import os
import tempfile

from loguru import logger
import pygame
from openai import OpenAI


class OpenAITTS:
    """
    A class to interact with OpenAI API and play the generated audio with improved streaming capabilities.
    """

    def __init__(self, *args, **kwargs):
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs)
        pygame.init()

    def run(self, task: str, play_sound: bool = True, *args, **kwargs):
        """
        Run a task with the OpenAI API and optionally play the generated audio with improved streaming.

        Args:
            task (str): The task to be executed.
            play_sound (bool): If True, play the generated audio.

        Returns:
            None
        """
        try:
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="nova",
                input=task,
                *args,
                **kwargs,
            )
            logger.info("Task completed successfully.")

            if play_sound:
                # The speech endpoint returns binary audio rather than a URL;
                # write it to a temporary file before playback.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                    tmp_file.write(response.content)
                pygame.mixer.music.load(tmp_file.name)
                pygame.mixer.music.play()
                while pygame.mixer.music.get_busy():
                    pygame.time.Clock().tick(10)
        except Exception as e:
            logger.error(f"Error during task execution: {str(e)}")


# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY"))
# client.run("Hello world! This is a streaming test.", play_sound=True)


def text_to_speech(task: str, play_sound: bool = True, *args, **kwargs):
    out = OpenAITTS().run(task, play_sound=play_sound, *args, **kwargs)
    return out


# print(text_to_speech(task="hello"))

@ -1,50 +1,64 @@
import re


# Previous version (replaced by this commit):

def extract_code_from_markdown(markdown_content: str) -> str:
    """
    Extracts code blocks from a Markdown string and returns them as a single string.

    Args:
    - markdown_content (str): The Markdown content as a string.

    Returns:
    - str: A single string containing all the code blocks separated by newlines.
    """
    # Regular expression for fenced code blocks with optional language specifier
    pattern = r"```(?:\w+\n)?(.*?)```"

    # Check if markdown_content is a string
    if not isinstance(markdown_content, str):
        raise TypeError("markdown_content must be a string")

    # Find all matches of the pattern
    matches = re.finditer(pattern, markdown_content, re.DOTALL)

    # Extract the content inside the backticks
    code_blocks = []
    for match in matches:
        code_block = match.group(1).strip()
        # Remove any leading or trailing whitespace from the code block
        code_block = code_block.strip()
        # Remove any empty lines from the code block
        code_block = "\n".join(
            [line for line in code_block.split("\n") if line.strip()]
        )
        code_blocks.append(code_block)

    # Concatenate all code blocks separated by newlines
    if code_blocks:
        return "\n\n".join(code_blocks)
    else:
        return ""


# example = """
# hello im an agent
# ```bash
# pip install swarms
# ```
# """

# print(extract_code_from_markdown(example))  # Output: { "type": "function", "function": { "name": "fetch_financial_news", "parameters": { "query": "Nvidia news", "num_articles": 5 } } }


# New version (added by this commit):

def extract_code_blocks_with_language(markdown_text: str):
    """
    Extracts all code blocks from Markdown text along with their languages.

    Args:
        markdown_text (str): The input Markdown text.

    Returns:
        list[dict]: A list of dictionaries, each containing:
            - 'language': The detected language (or 'plaintext' if none specified).
            - 'content': The content of the code block.
    """
    # Regex pattern to match code blocks and optional language specifiers
    pattern = r"```(\w+)?\n(.*?)```"

    # Find all matches (language and content)
    matches = re.findall(pattern, markdown_text, re.DOTALL)

    # Parse results
    code_blocks = []
    for language, content in matches:
        language = language.strip() if language else "plaintext"  # Default to 'plaintext'
        code_blocks.append({"language": language, "content": content.strip()})

    return code_blocks


def extract_code_from_markdown(markdown_text: str, language: str = None):
    """
    Extracts content of code blocks for a specific language or all blocks if no language specified.

    Args:
        markdown_text (str): The input Markdown text.
        language (str, optional): The language to filter by (e.g., 'yaml', 'python').

    Returns:
        str: The concatenated content of matched code blocks or an empty string if none found.
    """
    # Get all code blocks with detected languages
    code_blocks = extract_code_blocks_with_language(markdown_text)

    # Filter by language if specified
    if language:
        code_blocks = [
            block["content"] for block in code_blocks if block["language"] == language
        ]
    else:
        code_blocks = [block["content"] for block in code_blocks]  # Include all blocks

    # Return concatenated content
    return "\n\n".join(code_blocks) if code_blocks else ""

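# Example (illustrative, not part of the original diff): how the new helpers
# behave on a small Markdown snippet.
#
# sample = """Here is setup code:
# ```bash
# pip install swarms
# ```
# and a config:
# ```yaml
# model: gpt-4o-mini
# ```
# """
# print(extract_code_blocks_with_language(sample))
# # -> [{'language': 'bash', 'content': 'pip install swarms'},
# #     {'language': 'yaml', 'content': 'model: gpt-4o-mini'}]
# print(extract_code_from_markdown(sample, language="yaml"))
# # -> "model: gpt-4o-mini"
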
@ -1,51 +0,0 @@
import json

import yaml


def remove_whitespace_from_json(json_string: str) -> str:
    """
    Removes unnecessary whitespace from a JSON string.

    This function parses the JSON string into a Python object and then
    serializes it back into a JSON string without unnecessary whitespace.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with whitespace removed.
    """
    parsed = json.loads(json_string)
    return json.dumps(parsed, separators=(",", ":"))


# # Example usage for JSON
# json_string = '{"field1": 123, "field2": "example text"}'
# print(remove_whitespace_from_json(json_string))


def remove_whitespace_from_yaml(yaml_string: str) -> str:
    """
    Removes unnecessary whitespace from a YAML string.

    This function parses the YAML string into a Python object and then
    serializes it back into a YAML string with minimized whitespace.
    Note: This might change the representation style of YAML data.

    Args:
        yaml_string (str): The YAML string.

    Returns:
        str: The YAML string with whitespace reduced.
    """
    parsed = yaml.safe_load(yaml_string)
    return yaml.dump(parsed, default_flow_style=True)


# # Example usage for YAML
# yaml_string = """
# field1: 123
# field2: example text
# """
# print(remove_whitespace_from_yaml(yaml_string))

@ -1,292 +0,0 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from loguru import logger
import math


@dataclass
class StarAttentionConfig:
    """Configuration for StarAttention module.

    Attributes:
        hidden_size: Dimension of the model's hidden states
        num_attention_heads: Number of attention heads
        num_hosts: Number of hosts in the distributed system
        block_size: Size of each context block
        anchor_size: Size of the anchor block
        dropout_prob: Dropout probability (default: 0.1)
        layer_norm_eps: Layer normalization epsilon (default: 1e-12)
    """

    hidden_size: int
    num_attention_heads: int
    num_hosts: int
    block_size: int
    anchor_size: int
    dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12


class StarAttention(nn.Module):
    """
    Implementation of Star Attention mechanism for distributed inference.

    The module implements a two-phase attention mechanism:
    1. Local Context Encoding with Anchor Blocks
    2. Query Encoding and Output Generation with Global Attention
    """

    def __init__(self, config: StarAttentionConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} not divisible by number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Initialize components
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # KV cache for storing computed key/value pairs
        self.kv_cache = {}

        logger.info(f"Initialized StarAttention with config: {config}")

    def _split_heads(self, tensor: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Split the last dimension into (num_heads, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        tensor = tensor.view(batch_size, seq_len, num_heads, self.head_dim)
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return tensor.transpose(1, 2)

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the head dimension back into hidden_size."""
        batch_size, _, seq_len, _ = tensor.size()
        tensor = tensor.transpose(1, 2)
        return tensor.reshape(batch_size, seq_len, self.config.hidden_size)

    def _compute_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention scores and weighted values."""
        # Scale dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Online softmax computation
        attention_probs = torch.nn.functional.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context = torch.matmul(attention_probs, value)

        return context, attention_probs

    def phase1_local_context_encoding(
        self,
        input_ids: torch.Tensor,
        host_id: int,
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """
        Phase 1: Local Context Encoding with Anchor Blocks

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            host_id: ID of the current host
            device: Device to run computations on
        """
        logger.debug(f"Starting Phase 1 on host {host_id}")

        # Calculate block assignments
        block_start = host_id * self.config.block_size
        block_end = block_start + self.config.block_size

        # Get local block
        local_block = input_ids[:, block_start:block_end].to(device)

        # Get anchor block (first block)
        anchor_block = input_ids[:, : self.config.anchor_size].to(device)

        # Compute KV pairs for local block
        local_hidden = self.layer_norm(local_block)
        local_key = self._split_heads(
            self.key(local_hidden), self.config.num_attention_heads
        )
        local_value = self._split_heads(
            self.value(local_hidden), self.config.num_attention_heads
        )

        # Store in KV cache
        self.kv_cache[host_id] = {
            "key": local_key,
            "value": local_value,
            "anchor_key": (
                None
                if host_id == 0
                else self._split_heads(
                    self.key(self.layer_norm(anchor_block)),
                    self.config.num_attention_heads,
                )
            ),
        }

        logger.debug(
            f"Phase 1 complete on host {host_id}. KV cache shapes - "
            f"key: {local_key.shape}, value: {local_value.shape}"
        )

    def phase2_query_encoding(
        self,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Phase 2: Query Encoding and Output Generation

        Args:
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        logger.debug(f"Starting Phase 2 on host {host_id}")

        # Transform query
        query_hidden = self.layer_norm(query_input)
        query = self._split_heads(
            self.query(query_hidden), self.config.num_attention_heads
        )

        # Compute local attention scores
        local_context, local_probs = self._compute_attention_scores(
            query,
            self.kv_cache[host_id]["key"],
            self.kv_cache[host_id]["value"],
        )

        if not is_query_host:
            # Non-query hosts send their local attention statistics
            dist.send(local_probs, dst=self.config.num_hosts - 1)
            return None

        # Query host aggregates attention from all hosts
        all_attention_probs = [local_probs]
        for src_rank in range(self.config.num_hosts - 1):
            probs = torch.empty_like(local_probs)
            dist.recv(probs, src=src_rank)
            all_attention_probs.append(probs)

        # Compute global attention
        torch.mean(torch.stack(all_attention_probs), dim=0)

        # Final output computation
        output = self._merge_heads(local_context)
        output = self.dropout(output)

        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )

        return output

    def forward(
        self,
        input_ids: torch.Tensor,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Forward pass of the StarAttention module.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        # Phase 1: Local Context Encoding
        self.phase1_local_context_encoding(input_ids, host_id, device)

        # Phase 2: Query Encoding and Output Generation
        return self.phase2_query_encoding(query_input, host_id, is_query_host, device)


# Example forward pass
config = StarAttentionConfig(
    hidden_size=768,
    num_attention_heads=12,
    num_hosts=3,
    block_size=512,
    anchor_size=128,
)

# Initialize model
model = StarAttention(config)

# Example input tensors
batch_size = 4
seq_len = 512
input_ids = torch.randint(0, 1000, (batch_size, seq_len))  # Random input IDs
query_input = torch.randn(batch_size, seq_len, config.hidden_size)  # Random query input

# Example forward pass for query host (host_id = 2)
output = model(
    input_ids=input_ids,
    query_input=query_input,
    host_id=2,
    is_query_host=True,
    device="cpu",
)

print(output)