parent
e6d616b2c3
commit
c2cd677575
@ -1,18 +1,26 @@
|
|||||||
|
/* * Further customization as needed */ */
|
||||||
/* Further customization as needed */
|
|
||||||
|
|
||||||
|
|
||||||
.md-typeset__table {
|
.md-typeset__table {
|
||||||
min-width: 100%;
|
min-width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.md-typeset table:not([class]) {
|
||||||
|
display: table;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dark mode */
|
||||||
|
[data-md-color-scheme="slate"] {
|
||||||
|
--md-default-bg-color: black;
|
||||||
|
}
|
||||||
|
|
||||||
.md-typeset table:not([class]) {
|
.header__ellipsis {
|
||||||
display: table;
|
color: black;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
:root {
|
:root {
|
||||||
--md-primary-fg-color: #EE0F0F;
|
--md-primary-fg-color: #EE0F0F;
|
||||||
--md-primary-fg-color--light: #ECB7B7;
|
--md-primary-fg-color--light: #ECB7B7;
|
||||||
--md-primary-fg-color--dark: #90030C;
|
--md-primary-fg-color--dark: #90030C;
|
||||||
} */
|
} */
|
@ -0,0 +1,308 @@
|
|||||||
|
import os
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
from web3 import Web3
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
import asyncio
|
||||||
|
from loguru import logger
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import csv
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
BLOCKCHAIN_AGENT_PROMPT = """
|
||||||
|
You are an expert blockchain and cryptocurrency analyst with deep knowledge of Ethereum markets and DeFi ecosystems.
|
||||||
|
You have access to real-time ETH price data and transaction information.
|
||||||
|
|
||||||
|
For each transaction, analyze:
|
||||||
|
|
||||||
|
1. MARKET CONTEXT
|
||||||
|
- Current ETH price and what this transaction means in USD terms
|
||||||
|
- How this movement compares to typical market volumes
|
||||||
|
- Whether this could impact ETH price
|
||||||
|
|
||||||
|
2. BEHAVIORAL ANALYSIS
|
||||||
|
- Whether this appears to be institutional, whale, or protocol movement
|
||||||
|
- If this fits any known wallet patterns or behaviors
|
||||||
|
- Signs of smart contract interaction or DeFi activity
|
||||||
|
|
||||||
|
3. RISK & IMPLICATIONS
|
||||||
|
- Potential market impact or price influence
|
||||||
|
- Signs of potential market manipulation or unusual activity
|
||||||
|
- Protocol or DeFi risks if applicable
|
||||||
|
|
||||||
|
4. STRATEGIC INSIGHTS
|
||||||
|
- What traders should know about this movement
|
||||||
|
- Potential chain reactions or follow-up effects
|
||||||
|
- Market opportunities or risks created
|
||||||
|
|
||||||
|
Write naturally but precisely. Focus on actionable insights and important patterns.
|
||||||
|
Your analysis helps traders and researchers understand significant market movements in real-time."""
|
||||||
|
|
||||||
|
|
||||||
|
class EthereumAnalyzer:
|
||||||
|
def __init__(self, min_value_eth: float = 100.0):
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger.add(
|
||||||
|
"eth_analysis.log",
|
||||||
|
rotation="500 MB",
|
||||||
|
retention="10 days",
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.w3 = Web3(
|
||||||
|
Web3.HTTPProvider(
|
||||||
|
"https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not self.w3.is_connected():
|
||||||
|
raise ConnectionError(
|
||||||
|
"Failed to connect to Ethereum network"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_value_eth = min_value_eth
|
||||||
|
self.last_processed_block = self.w3.eth.block_number
|
||||||
|
self.eth_price = self.get_eth_price()
|
||||||
|
self.last_price_update = time.time()
|
||||||
|
|
||||||
|
# Initialize AI agent
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI API key not found in environment variables"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = OpenAIChat(
|
||||||
|
openai_api_key=api_key,
|
||||||
|
model_name="gpt-4",
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.agent = Agent(
|
||||||
|
agent_name="Ethereum-Analysis-Agent",
|
||||||
|
system_prompt=BLOCKCHAIN_AGENT_PROMPT,
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
verbose=True,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
saved_state_path="eth_agent.json",
|
||||||
|
user_name="eth_analyzer",
|
||||||
|
retry_attempts=1,
|
||||||
|
context_length=200000,
|
||||||
|
output_type="string",
|
||||||
|
streaming_on=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.csv_filename = "ethereum_analysis.csv"
|
||||||
|
self.initialize_csv()
|
||||||
|
|
||||||
|
def get_eth_price(self) -> float:
|
||||||
|
"""Get current ETH price from CoinGecko API."""
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.coingecko.com/api/v3/simple/price",
|
||||||
|
params={"ids": "ethereum", "vs_currencies": "usd"},
|
||||||
|
)
|
||||||
|
return float(response.json()["ethereum"]["usd"])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching ETH price: {str(e)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def update_eth_price(self):
|
||||||
|
"""Update ETH price if more than 5 minutes have passed."""
|
||||||
|
if time.time() - self.last_price_update > 300: # 5 minutes
|
||||||
|
self.eth_price = self.get_eth_price()
|
||||||
|
self.last_price_update = time.time()
|
||||||
|
logger.info(f"Updated ETH price: ${self.eth_price:,.2f}")
|
||||||
|
|
||||||
|
def initialize_csv(self):
|
||||||
|
"""Initialize CSV file with headers."""
|
||||||
|
headers = [
|
||||||
|
"timestamp",
|
||||||
|
"transaction_hash",
|
||||||
|
"from_address",
|
||||||
|
"to_address",
|
||||||
|
"value_eth",
|
||||||
|
"value_usd",
|
||||||
|
"eth_price",
|
||||||
|
"gas_used",
|
||||||
|
"gas_price_gwei",
|
||||||
|
"block_number",
|
||||||
|
"analysis",
|
||||||
|
]
|
||||||
|
|
||||||
|
if not os.path.exists(self.csv_filename):
|
||||||
|
with open(self.csv_filename, "w", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(headers)
|
||||||
|
|
||||||
|
async def analyze_transaction(
|
||||||
|
self, tx_hash: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Analyze a single transaction."""
|
||||||
|
try:
|
||||||
|
tx = self.w3.eth.get_transaction(tx_hash)
|
||||||
|
receipt = self.w3.eth.get_transaction_receipt(tx_hash)
|
||||||
|
|
||||||
|
value_eth = float(self.w3.from_wei(tx.value, "ether"))
|
||||||
|
|
||||||
|
if value_eth < self.min_value_eth:
|
||||||
|
return None
|
||||||
|
|
||||||
|
block = self.w3.eth.get_block(tx.blockNumber)
|
||||||
|
|
||||||
|
# Update ETH price if needed
|
||||||
|
self.update_eth_price()
|
||||||
|
|
||||||
|
value_usd = value_eth * self.eth_price
|
||||||
|
|
||||||
|
analysis = {
|
||||||
|
"timestamp": datetime.fromtimestamp(
|
||||||
|
block.timestamp
|
||||||
|
).isoformat(),
|
||||||
|
"transaction_hash": tx_hash.hex(),
|
||||||
|
"from_address": tx["from"],
|
||||||
|
"to_address": tx.to if tx.to else "Contract Creation",
|
||||||
|
"value_eth": value_eth,
|
||||||
|
"value_usd": value_usd,
|
||||||
|
"eth_price": self.eth_price,
|
||||||
|
"gas_used": receipt.gasUsed,
|
||||||
|
"gas_price_gwei": float(
|
||||||
|
self.w3.from_wei(tx.gasPrice, "gwei")
|
||||||
|
),
|
||||||
|
"block_number": tx.blockNumber,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if it's a contract
|
||||||
|
if tx.to:
|
||||||
|
code = self.w3.eth.get_code(tx.to)
|
||||||
|
analysis["is_contract"] = len(code) > 0
|
||||||
|
|
||||||
|
# Get contract events
|
||||||
|
if analysis["is_contract"]:
|
||||||
|
analysis["events"] = receipt.logs
|
||||||
|
|
||||||
|
return analysis
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error analyzing transaction {tx_hash}: {str(e)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_analysis_prompt(self, tx_data: Dict[str, Any]) -> str:
|
||||||
|
"""Prepare detailed analysis prompt including price context."""
|
||||||
|
value_usd = tx_data["value_usd"]
|
||||||
|
eth_price = tx_data["eth_price"]
|
||||||
|
|
||||||
|
prompt = f"""Analyze this Ethereum transaction in current market context:
|
||||||
|
|
||||||
|
Transaction Details:
|
||||||
|
- Value: {tx_data['value_eth']:.2f} ETH (${value_usd:,.2f} at current price)
|
||||||
|
- Current ETH Price: ${eth_price:,.2f}
|
||||||
|
- From: {tx_data['from_address']}
|
||||||
|
- To: {tx_data['to_address']}
|
||||||
|
- Contract Interaction: {tx_data.get('is_contract', False)}
|
||||||
|
- Gas Used: {tx_data['gas_used']:,} units
|
||||||
|
- Gas Price: {tx_data['gas_price_gwei']:.2f} Gwei
|
||||||
|
- Block: {tx_data['block_number']}
|
||||||
|
- Timestamp: {tx_data['timestamp']}
|
||||||
|
|
||||||
|
{f"Event Count: {len(tx_data['events'])} events" if tx_data.get('events') else "No contract events"}
|
||||||
|
|
||||||
|
Consider the transaction's significance given the current ETH price of ${eth_price:,.2f} and total USD value of ${value_usd:,.2f}.
|
||||||
|
Analyze market impact, patterns, risks, and strategic implications."""
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def save_to_csv(self, tx_data: Dict[str, Any], ai_analysis: str):
|
||||||
|
"""Save transaction data and analysis to CSV."""
|
||||||
|
row = [
|
||||||
|
tx_data["timestamp"],
|
||||||
|
tx_data["transaction_hash"],
|
||||||
|
tx_data["from_address"],
|
||||||
|
tx_data["to_address"],
|
||||||
|
tx_data["value_eth"],
|
||||||
|
tx_data["value_usd"],
|
||||||
|
tx_data["eth_price"],
|
||||||
|
tx_data["gas_used"],
|
||||||
|
tx_data["gas_price_gwei"],
|
||||||
|
tx_data["block_number"],
|
||||||
|
ai_analysis.replace("\n", " "),
|
||||||
|
]
|
||||||
|
|
||||||
|
with open(self.csv_filename, "a", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
async def monitor_transactions(self):
|
||||||
|
"""Monitor and analyze transactions one at a time."""
|
||||||
|
logger.info(
|
||||||
|
f"Starting transaction monitor (minimum value: {self.min_value_eth} ETH)"
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
current_block = self.w3.eth.block_number
|
||||||
|
block = self.w3.eth.get_block(
|
||||||
|
current_block, full_transactions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for tx in block.transactions:
|
||||||
|
tx_analysis = await self.analyze_transaction(
|
||||||
|
tx.hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if tx_analysis:
|
||||||
|
# Get AI analysis
|
||||||
|
analysis_prompt = (
|
||||||
|
self.prepare_analysis_prompt(tx_analysis)
|
||||||
|
)
|
||||||
|
ai_analysis = self.agent.run(analysis_prompt)
|
||||||
|
print(ai_analysis)
|
||||||
|
|
||||||
|
# Save to CSV
|
||||||
|
self.save_to_csv(tx_analysis, ai_analysis)
|
||||||
|
|
||||||
|
# Print analysis
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("New Transaction Analysis")
|
||||||
|
print(
|
||||||
|
f"Hash: {tx_analysis['transaction_hash']}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Value: {tx_analysis['value_eth']:.2f} ETH (${tx_analysis['value_usd']:,.2f})"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Current ETH Price: ${self.eth_price:,.2f}"
|
||||||
|
)
|
||||||
|
print("=" * 50)
|
||||||
|
print(ai_analysis)
|
||||||
|
print("=" * 50 + "\n")
|
||||||
|
|
||||||
|
await asyncio.sleep(1) # Wait for next block
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in monitoring loop: {str(e)}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Entry point for the analysis system."""
|
||||||
|
analyzer = EthereumAnalyzer(min_value_eth=100.0)
|
||||||
|
await analyzer.monitor_transactions()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Starting Ethereum Transaction Analyzer...")
|
||||||
|
print("Saving results to ethereum_analysis.csv")
|
||||||
|
print("Press Ctrl+C to stop")
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping analyzer...")
|
Can't render this file because it has a wrong number of fields in line 4.
|
@ -0,0 +1,292 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
from loguru import logger
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StarAttentionConfig:
|
||||||
|
"""Configuration for StarAttention module.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hidden_size: Dimension of the model's hidden states
|
||||||
|
num_attention_heads: Number of attention heads
|
||||||
|
num_hosts: Number of hosts in the distributed system
|
||||||
|
block_size: Size of each context block
|
||||||
|
anchor_size: Size of the anchor block
|
||||||
|
dropout_prob: Dropout probability (default: 0.1)
|
||||||
|
layer_norm_eps: Layer normalization epsilon (default: 1e-12)
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
num_hosts: int
|
||||||
|
block_size: int
|
||||||
|
anchor_size: int
|
||||||
|
dropout_prob: float = 0.1
|
||||||
|
layer_norm_eps: float = 1e-12
|
||||||
|
|
||||||
|
|
||||||
|
class StarAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of Star Attention mechanism for distributed inference.
|
||||||
|
|
||||||
|
The module implements a two-phase attention mechanism:
|
||||||
|
1. Local Context Encoding with Anchor Blocks
|
||||||
|
2. Query Encoding and Output Generation with Global Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: StarAttentionConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {config.hidden_size} not divisible by number of attention "
|
||||||
|
f"heads {config.num_attention_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.head_dim = (
|
||||||
|
config.hidden_size // config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.query = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.key = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.dropout_prob)
|
||||||
|
self.layer_norm = nn.LayerNorm(
|
||||||
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
# KV cache for storing computed key/value pairs
|
||||||
|
self.kv_cache = {}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized StarAttention with config: {config}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _split_heads(
|
||||||
|
self, tensor: torch.Tensor, num_heads: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Split the last dimension into (num_heads, head_dim)."""
|
||||||
|
batch_size, seq_len, _ = tensor.size()
|
||||||
|
tensor = tensor.view(
|
||||||
|
batch_size, seq_len, num_heads, self.head_dim
|
||||||
|
)
|
||||||
|
# Transpose to (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
return tensor.transpose(1, 2)
|
||||||
|
|
||||||
|
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Merge the head dimension back into hidden_size."""
|
||||||
|
batch_size, _, seq_len, _ = tensor.size()
|
||||||
|
tensor = tensor.transpose(1, 2)
|
||||||
|
return tensor.reshape(
|
||||||
|
batch_size, seq_len, self.config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_attention_scores(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute attention scores and weighted values."""
|
||||||
|
# Scale dot-product attention
|
||||||
|
scores = torch.matmul(
|
||||||
|
query, key.transpose(-2, -1)
|
||||||
|
) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores.masked_fill(mask == 0, float("-inf"))
|
||||||
|
|
||||||
|
# Online softmax computation
|
||||||
|
attention_probs = torch.nn.functional.softmax(scores, dim=-1)
|
||||||
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
context = torch.matmul(attention_probs, value)
|
||||||
|
|
||||||
|
return context, attention_probs
|
||||||
|
|
||||||
|
def phase1_local_context_encoding(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Phase 1: Local Context Encoding with Anchor Blocks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||||
|
host_id: ID of the current host
|
||||||
|
device: Device to run computations on
|
||||||
|
"""
|
||||||
|
logger.debug(f"Starting Phase 1 on host {host_id}")
|
||||||
|
|
||||||
|
# Calculate block assignments
|
||||||
|
block_start = host_id * self.config.block_size
|
||||||
|
block_end = block_start + self.config.block_size
|
||||||
|
|
||||||
|
# Get local block
|
||||||
|
local_block = input_ids[:, block_start:block_end].to(device)
|
||||||
|
|
||||||
|
# Get anchor block (first block)
|
||||||
|
anchor_block = input_ids[:, : self.config.anchor_size].to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute KV pairs for local block
|
||||||
|
local_hidden = self.layer_norm(local_block)
|
||||||
|
local_key = self._split_heads(
|
||||||
|
self.key(local_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
local_value = self._split_heads(
|
||||||
|
self.value(local_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store in KV cache
|
||||||
|
self.kv_cache[host_id] = {
|
||||||
|
"key": local_key,
|
||||||
|
"value": local_value,
|
||||||
|
"anchor_key": (
|
||||||
|
None
|
||||||
|
if host_id == 0
|
||||||
|
else self._split_heads(
|
||||||
|
self.key(self.layer_norm(anchor_block)),
|
||||||
|
self.config.num_attention_heads,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Phase 1 complete on host {host_id}. KV cache shapes - "
|
||||||
|
f"key: {local_key.shape}, value: {local_value.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def phase2_query_encoding(
|
||||||
|
self,
|
||||||
|
query_input: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
is_query_host: bool,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Phase 2: Query Encoding and Output Generation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||||
|
host_id: ID of the current host
|
||||||
|
is_query_host: Whether this host is the query host
|
||||||
|
device: Device to run computations on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor if this is the query host, None otherwise
|
||||||
|
"""
|
||||||
|
logger.debug(f"Starting Phase 2 on host {host_id}")
|
||||||
|
|
||||||
|
# Transform query
|
||||||
|
query_hidden = self.layer_norm(query_input)
|
||||||
|
query = self._split_heads(
|
||||||
|
self.query(query_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute local attention scores
|
||||||
|
local_context, local_probs = self._compute_attention_scores(
|
||||||
|
query,
|
||||||
|
self.kv_cache[host_id]["key"],
|
||||||
|
self.kv_cache[host_id]["value"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_query_host:
|
||||||
|
# Non-query hosts send their local attention statistics
|
||||||
|
dist.send(local_probs, dst=self.config.num_hosts - 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Query host aggregates attention from all hosts
|
||||||
|
all_attention_probs = [local_probs]
|
||||||
|
for src_rank in range(self.config.num_hosts - 1):
|
||||||
|
probs = torch.empty_like(local_probs)
|
||||||
|
dist.recv(probs, src=src_rank)
|
||||||
|
all_attention_probs.append(probs)
|
||||||
|
|
||||||
|
# Compute global attention
|
||||||
|
torch.mean(torch.stack(all_attention_probs), dim=0)
|
||||||
|
|
||||||
|
# Final output computation
|
||||||
|
output = self._merge_heads(local_context)
|
||||||
|
output = self.dropout(output)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
query_input: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
is_query_host: bool,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass of the StarAttention module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||||
|
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||||
|
host_id: ID of the current host
|
||||||
|
is_query_host: Whether this host is the query host
|
||||||
|
device: Device to run computations on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor if this is the query host, None otherwise
|
||||||
|
"""
|
||||||
|
# Phase 1: Local Context Encoding
|
||||||
|
self.phase1_local_context_encoding(input_ids, host_id, device)
|
||||||
|
|
||||||
|
# Phase 2: Query Encoding and Output Generation
|
||||||
|
return self.phase2_query_encoding(
|
||||||
|
query_input, host_id, is_query_host, device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Example forward pass
|
||||||
|
config = StarAttentionConfig(
|
||||||
|
hidden_size=768,
|
||||||
|
num_attention_heads=12,
|
||||||
|
num_hosts=3,
|
||||||
|
block_size=512,
|
||||||
|
anchor_size=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
model = StarAttention(config)
|
||||||
|
|
||||||
|
# Example input tensors
|
||||||
|
batch_size = 4
|
||||||
|
seq_len = 512
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, 1000, (batch_size, seq_len)
|
||||||
|
) # Random input IDs
|
||||||
|
query_input = torch.randn(
|
||||||
|
batch_size, seq_len, config.hidden_size
|
||||||
|
) # Random query input
|
||||||
|
|
||||||
|
# Example forward pass for query host (host_id = 2)
|
||||||
|
output = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
query_input=query_input,
|
||||||
|
host_id=2,
|
||||||
|
is_query_host=True,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(output)
|
Loading…
Reference in new issue