parent 3ab65bf559
commit 2f412dcdc0

@@ -0,0 +1,188 @@
import json
import os
from contextlib import suppress
from typing import Any, Callable, Dict, Optional, Type, Union

from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError, create_model
from swarm_models.openai_function_caller import OpenAIFunctionCaller

class DynamicParser:
    @staticmethod
    def extract_fields(model: Type[BaseModel]) -> Dict[str, Any]:
        return {
            field_name: (field.annotation, ... if field.is_required() else None)
            for field_name, field in model.model_fields.items()
        }

    @staticmethod
    def create_partial_model(model: Type[BaseModel], data: Dict[str, Any]) -> Type[BaseModel]:
        fields = {
            field_name: (field.annotation, ... if field.is_required() else None)
            for field_name, field in model.model_fields.items()
            if field_name in data
        }
        return create_model(f"Partial{model.__name__}", **fields)

    @classmethod
    def parse(cls, data: Union[str, Dict[str, Any]], model: Type[BaseModel]) -> Optional[BaseModel]:
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError:
                return None

        # Try full model first
        with suppress(ValidationError):
            return model.model_validate(data)

        # Create and try partial model
        partial_model = cls.create_partial_model(model, data)
        with suppress(ValidationError):
            return partial_model.model_validate(data)

        return None

load_dotenv()


# Define the Thoughts schema
class Thoughts(BaseModel):
    text: str = Field(..., description="Current thoughts or observations regarding the task.")
    reasoning: str = Field(..., description="Logical reasoning behind the thought process.")
    plan: str = Field(..., description="A short bulleted list that conveys the immediate and long-term plan.")
    criticism: str = Field(..., description="Constructive self-criticism to improve future responses.")
    speak: str = Field(..., description="A concise summary of thoughts intended for the user.")


# Define the Command schema
class Command(BaseModel):
    name: str = Field(..., description="Command name to execute from the provided list of commands.")
    args: Dict[str, Any] = Field(..., description="Arguments required to execute the command.")


# Define the AgentResponse schema
class AgentResponse(BaseModel):
    thoughts: Thoughts = Field(..., description="The agent's current thoughts and reasoning.")
    command: Command = Field(..., description="The command to execute along with its arguments.")

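# Illustrative only (values made up): a complete payload validates as a full
# AgentResponse; a payload with missing sections falls back to the dynamically
# created partial model instead of failing outright.
# raw = '{"command": {"name": "do_nothing", "args": {}}}'
# DynamicParser.parse(raw, AgentResponse)
# # -> instance of the generated "PartialAgentResponse" model, or None if unparseable
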
# Define tool functions
def fluid_api_command(task: str):
    """Execute a fluid API request."""
    # response = fluid_api_request(task)
    # NOTE: fluid_api_request is not imported in this module, so the call above is
    # left commented out; a stub payload keeps the command map callable until the
    # real client is wired in.
    response = {"status": "success", "message": f"Fluid API request received: {task}"}
    print(json.dumps(response, indent=4))
    return response


def send_tweet_command(text: str):
    """Simulate sending a tweet."""
    print(f"Tweet sent: {text}")
    return {"status": "success", "message": f"Tweet sent: {text}"}


def do_nothing_command():
    """Do nothing."""
    print("Doing nothing...")
    return {"status": "success", "message": "No action taken."}


def task_complete_command(reason: str):
    """Mark the task as complete and provide a reason."""
    print(f"Task completed: {reason}")
    return {"status": "success", "message": f"Task completed: {reason}"}

# Dynamic command execution
def execute_command(name: str, args: Dict[str, Any]):
    """Dynamically execute a command based on its name and arguments."""
    command_map: Dict[str, Callable] = {
        "fluid_api": lambda **kwargs: fluid_api_command(task=kwargs.get("task")),
        "send_tweet": lambda **kwargs: send_tweet_command(text=kwargs.get("text")),
        "do_nothing": lambda **kwargs: do_nothing_command(),
        "task_complete": lambda **kwargs: task_complete_command(reason=kwargs.get("reason")),
    }

    if name not in command_map:
        raise ValueError(f"Unknown command: {name}")

    # Execute the command with the provided arguments
    return command_map[name](**args)

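# Illustrative dispatch (values are made up):
# execute_command("send_tweet", {"text": "Hello from the agent"})
# # -> prints "Tweet sent: Hello from the agent" and returns a status dict
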
def parse_and_execute_command(response: Union[str, Dict[str, Any]], base_model: Type[BaseModel] = AgentResponse) -> Any:
    """Enhanced command parser with flexible input handling."""
    parsed = DynamicParser.parse(response, base_model)
    if not parsed:
        raise ValueError("Failed to parse response")

    if hasattr(parsed, "command"):
        command_name = parsed.command.name
        command_args = parsed.command.args
        return execute_command(command_name, command_args)

    return parsed

ainame = "AutoAgent"
|
||||||
|
userprovided = "assistant"
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = f"""
|
||||||
|
You are {ainame}, an advanced and autonomous {userprovided}.
|
||||||
|
Your role is to make decisions and complete tasks independently without seeking user assistance. Leverage your strengths as an LLM to solve tasks efficiently, adhering strictly to the commands and resources provided.
|
||||||
|
|
||||||
|
### GOALS:
|
||||||
|
1. {userprovided}
|
||||||
|
2. Execute tasks with precision and efficiency.
|
||||||
|
3. Ensure outputs are actionable and aligned with the user's objectives.
|
||||||
|
4. Continuously optimize task strategies for maximum effectiveness.
|
||||||
|
5. Maintain reliability and consistency in all responses.
|
||||||
|
|
||||||
|
### CONSTRAINTS:
|
||||||
|
1. Memory limit: ~4000 words for short-term memory. Save essential information to files immediately to avoid loss.
|
||||||
|
2. Independent decision-making: Do not rely on user assistance.
|
||||||
|
3. Exclusively use commands in double quotes (e.g., "command name").
|
||||||
|
4. Use subprocesses for commands that may take longer than a few minutes.
|
||||||
|
5. Ensure all outputs strictly adhere to the specified JSON response format.
|
||||||
|
|
||||||
|
### COMMANDS:
|
||||||
|
1. Fluid API: "fluid_api", args: "method": "<GET/POST/...>", "url": "<url>", "headers": "<headers>", "body": "<payload>"
|
||||||
|
18. Send Tweet: "send_tweet", args: "text": "<text>"
|
||||||
|
19. Do Nothing: "do_nothing", args:
|
||||||
|
20. Task Complete (Shutdown): "task_complete", args: "reason": "<reason>"
|
||||||
|
|
||||||
|
### RESOURCES:
|
||||||
|
1. Internet access for real-time information and data gathering.
|
||||||
|
2. Long-term memory management for storing critical information.
|
||||||
|
3. Access to GPT-3.5-powered Agents for delegating tasks.
|
||||||
|
4. File handling capabilities for output storage and retrieval.
|
||||||
|
|
||||||
|
### PERFORMANCE EVALUATION:
|
||||||
|
1. Continuously analyze and reflect on actions to ensure optimal task completion.
|
||||||
|
2. Self-critique decisions and strategies constructively to identify areas for improvement.
|
||||||
|
3. Ensure every command serves a clear purpose and minimizes resource usage.
|
||||||
|
4. Complete tasks in the least number of steps, balancing speed and accuracy.
|
||||||
|
|
||||||
|
### RESPONSE FORMAT:
|
||||||
|
Always respond in a strict JSON format as described below. Ensure your responses can be parsed with Python's `json.loads`:
|
||||||
|
"""
|
||||||
|
|
||||||
|
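# Illustrative response shape implied by the AgentResponse schema above (field
# values are placeholders, not produced by this file):
# {
#     "thoughts": {
#         "text": "...", "reasoning": "...", "plan": "- step 1\n- step 2",
#         "criticism": "...", "speak": "..."
#     },
#     "command": {"name": "send_tweet", "args": {"text": "..."}}
# }
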
# Initialize the OpenAIFunctionCaller
model = OpenAIFunctionCaller(
    system_prompt=SYSTEM_PROMPT,
    max_tokens=4000,
    temperature=0.9,
    base_model=AgentResponse,  # Pass the Pydantic schema as the base model
    parallel_tool_calls=False,
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)

# Example usage
user_input = (
    "Analyze the provided Python code for inefficiencies, generate suggestions for improvements, "
    "and provide optimized code."
)

response = model.run(user_input)
response = parse_and_execute_command(response)
print(response)

@@ -0,0 +1,898 @@
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple, Union
import io
import struct
import wave
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
from loguru import logger
from PIL import Image
import magic  # python-magic; required by ModalityDetector below (was missing)

@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False

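# Illustrative sanity check, not part of the original module: with the defaults
# above, each key/value head serves num_query_heads // num_key_value_heads = 4
# query heads, and head_dim = hidden_size // num_query_heads = 32.
# cfg = ModelConfig()
# assert cfg.num_query_heads % cfg.num_key_value_heads == 0
# assert cfg.hidden_size % cfg.num_query_heads == 0
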
class MultiQueryAttention(nn.Module):
    """Fixed Multi-Query Attention implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_query_heads * self.head_dim
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim, config.hidden_size
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project and reshape
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [seq_len, batch, heads, head_dim]
        q = q.view(
            batch_size,
            seq_length,
            self.num_query_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)
        k = k.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)
        v = v.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(1, 0, 2, 3)

        # Apply rotary embeddings
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Handle MQA head expansion
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=2,
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=2,
            )

        # Compute attention
        # Reshape for matmul: [batch, heads, seq_length, head_dim]
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)

        attn_weights = (
            torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)

        return output

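# Illustrative shape check (commented so nothing runs at import time; the config
# values are arbitrary): attention preserves [batch, seq, hidden_size].
# attn = MultiQueryAttention(
#     ModelConfig(hidden_size=128, num_query_heads=4, num_key_value_heads=2)
# )
# assert attn(torch.randn(2, 16, 128)).shape == (2, 16, 128)
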
class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size
        )

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(
                                config.hidden_size,
                                4 * config.hidden_size,
                            ),
                            nn.GELU(),
                            nn.Linear(
                                4 * config.hidden_size,
                                config.hidden_size,
                            ),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Create causal mask if needed
        if attention_mask is None:
            causal = torch.triu(
                torch.ones(
                    (input_ids.size(1), input_ids.size(1)),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            # Build the additive mask as floats; filling a bool tensor with -inf
            # (as the original code attempted) is not valid.
            attention_mask = torch.zeros(
                causal.shape,
                device=input_ids.device,
                dtype=hidden_states.dtype,
            ).masked_fill(causal, float("-inf"))

        # Apply transformer layers
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss

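    # Illustrative training step (commented; assumes next-byte prediction with
    # targets shifted by one and a plain AdamW optimizer, neither of which is
    # defined in this file):
    # optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # loss = model.compute_loss(batch[:, :-1], batch[:, 1:])
    # loss.backward(); optimizer.step(); optimizer.zero_grad()
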
    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        logits[i, token_id] /= repetition_penalty

            # Apply top-k sampling
            if top_k is not None:
                indices_to_remove = (
                    logits
                    < torch.topk(logits, top_k)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(
                    logits, dtype=torch.bool
                )
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        return tensor_to_data(tensor_data)


# import torch
# from typing import Optional

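# Illustrative call (commented; assumes a trained model and a prompt of raw
# UTF-8 bytes mapped one-to-one to token ids):
# prompt = torch.tensor([list("Hello".encode("utf-8"))], dtype=torch.long)
# decoded = model.generate(prompt, max_new_tokens=32, top_p=0.9)
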
class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"

class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(
            byte_indices.cpu().numpy().astype(np.uint8).tolist()
        )

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Try with error handling
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If that failed, assume raw pixel data
            if not size:
                # Try to determine size from byte count
                pixel_count = len(byte_sequence) // len(mode)
                size = (
                    int(np.sqrt(pixel_count)),
                    int(np.sqrt(pixel_count)),
                )

            # Convert raw bytes to pixel array
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((*size, len(mode)))

            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine format string based on sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack bytes to samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to desired format.

        Args:
            model_output: Either tensor from model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError(
                "Video decoding not yet implemented"
            )
        else:  # BINARY
            return byte_sequence


# Usage example

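# Illustrative round trip (commented; byte ids must lie in [0, 255]):
# ids = torch.tensor([72, 101, 108, 108, 111])  # "Hello"
# ByteDetokenizer.decode_text(ByteDetokenizer.tensor_to_bytes(ids))  # -> "Hello"
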
class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None

class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers)
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"RIFF": "WEBP",
        # Audio
        # NOTE: real WAV headers carry a 4-byte size field between "RIFF" and
        # "WAVE", so this literal pattern will not match them via startswith.
        b"RIFF....WAVE": "WAV",
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }

    def __init__(self):
        self.magic = magic.Magic(mime=True)

    def _check_text_probability(self, data: bytes) -> float:
        """Estimate probability that data is text."""
        # Check if data is valid UTF-8
        try:
            data.decode("utf-8")
            # Count printable ASCII characters
            printable = sum(1 for b in data if 32 <= b <= 126)
            return printable / len(data)
        except UnicodeDecodeError:
            return 0.0

    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is a valid image and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                img = Image.open(bio)
                return True, {
                    "format": img.format,
                    "size": img.size,
                    "mode": img.mode,
                }
        except Exception:
            return False, {}

    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is valid audio and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                # Try to parse as WAV
                with wave.open(bio) as wav:
                    return True, {
                        "channels": wav.getnchannels(),
                        "sample_width": wav.getsampwidth(),
                        "framerate": wav.getframerate(),
                        "frames": wav.getnframes(),
                    }
        except Exception:
            # Check for other audio signatures
            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
                if data.startswith(sig):
                    return True, {"format": "compressed_audio"}
            return False, {}

    def _detect_boundaries(
        self, data: bytes
    ) -> List[Tuple[int, int, Modality]]:
        """Detect boundaries between different modalities in the data."""
        boundaries = []
        current_pos = 0

        while current_pos < len(data):
            # Look for known signatures
            for sig, format_type in self.SIGNATURES.items():
                if data[current_pos:].startswith(sig):
                    # Found a signature, determine its length
                    if format_type in ["JPEG", "PNG", "GIF"]:
                        # Find image end
                        try:
                            with io.BytesIO(
                                data[current_pos:]
                            ) as bio:
                                img = Image.open(bio)
                                img.verify()
                                size = bio.tell()
                                boundaries.append(
                                    (
                                        current_pos,
                                        current_pos + size,
                                        Modality.IMAGE,
                                    )
                                )
                                current_pos += size
                                continue
                        except Exception:
                            pass

            # Check for text sections
            text_prob = self._check_text_probability(
                data[current_pos : current_pos + 1024]
            )
            if text_prob > 0.8:
                # Look for end of text section
                end_pos = current_pos + 1
                while end_pos < len(data):
                    if (
                        self._check_text_probability(
                            data[end_pos : end_pos + 32]
                        )
                        < 0.5
                    ):
                        break
                    end_pos += 1
                boundaries.append(
                    (current_pos, end_pos, Modality.TEXT)
                )
                current_pos = end_pos
                continue

            current_pos += 1

        return boundaries

    def detect_modality(self, data: bytes) -> ModalityInfo:
        """Detect modality of byte sequence."""
        # First check for single modality
        mime_type = self.magic.from_buffer(data)

        # Check text
        text_prob = self._check_text_probability(data)
        if text_prob > 0.9:
            return ModalityInfo(
                modality=Modality.TEXT,
                confidence=text_prob,
                metadata={"mime_type": mime_type},
            )

        # Check image
        is_image, image_meta = self._check_image_validity(data)
        if is_image:
            return ModalityInfo(
                modality=Modality.IMAGE,
                confidence=1.0,
                metadata={**image_meta, "mime_type": mime_type},
            )

        # Check audio
        is_audio, audio_meta = self._check_audio_validity(data)
        if is_audio:
            return ModalityInfo(
                modality=Modality.AUDIO,
                confidence=1.0,
                metadata={**audio_meta, "mime_type": mime_type},
            )

        # Check for multimodal content
        boundaries = self._detect_boundaries(data)
        if len(boundaries) > 1:
            sub_modalities = []
            for start, end, modality in boundaries:
                chunk_data = data[start:end]
                sub_info = self.detect_modality(chunk_data)
                if sub_info.modality != Modality.BINARY:
                    sub_modalities.append(sub_info)

            if sub_modalities:
                return ModalityInfo(
                    modality=Modality.MULTIMODAL,
                    confidence=0.8,
                    metadata={"mime_type": "multipart/mixed"},
                    sub_modalities=sub_modalities,
                )

        # Default to binary
        return ModalityInfo(
            modality=Modality.BINARY,
            confidence=0.5,
            metadata={"mime_type": mime_type},
        )

    def split_modalities(
        self, data: bytes
    ) -> List[Tuple[Modality, bytes, Dict]]:
        """Split multimodal data into separate modalities."""
        boundaries = self._detect_boundaries(data)
        result = []

        for start, end, modality in boundaries:
            chunk = data[start:end]
            info = self.detect_modality(chunk)
            result.append((modality, chunk, info.metadata))

        return result

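# Illustrative use (commented; assumes the python-magic and Pillow packages and a
# local "sample.png" file, which is hypothetical here):
# info = ModalityDetector().detect_modality(open("sample.png", "rb").read())
# # info.modality -> Modality.IMAGE; info.metadata carries format/size/mode + MIME type
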
class AutoDetectBytesDecoder:
    """Decoder that automatically detects and decodes different modalities."""

    def __init__(self):
        self.detector = ModalityDetector()
        self.text_decoder = ByteDetokenizer()  # From previous example

    def decode(
        self, data: bytes
    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
        """Automatically detect and decode byte sequence."""
        info = self.detector.detect_modality(data)

        if info.modality == Modality.MULTIMODAL:
            # Handle multimodal content
            parts = self.detector.split_modalities(data)
            return [
                self.decode(chunk) for modality, chunk, _ in parts
            ]

        if info.modality == Modality.TEXT:
            return self.text_decoder.decode_text(data)
        elif info.modality == Modality.IMAGE:
            return self.text_decoder.decode_image(data)
        elif info.modality == Modality.AUDIO:
            return self.text_decoder.decode_audio(data)
        else:
            return data

# # Example usage
# def demo_auto_detection():
#     """Demonstrate auto modality detection."""
#     # Create mixed content
#     text = "Hello, World!".encode('utf-8')
#
#     # Create a small test image
#     img = Image.new('RGB', (100, 100), color='red')
#     img_bytes = io.BytesIO()
#     img.save(img_bytes, format='PNG')
#
#     # Combine into multimodal content
#     mixed_content = text + img_bytes.getvalue()
#
#     # Initialize decoder
#     decoder = AutoDetectBytesDecoder()
#
#     # Decode
#     result = decoder.decode(mixed_content)
#
#     if isinstance(result, list):
#         print("Detected multimodal content:")
#         for i, part in enumerate(result):
#             print(f"Part {i+1}: {type(part)}")
#
# if __name__ == "__main__":
#     demo_auto_detection()

def tensor_to_data(tensor: Tensor):
    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)

    # Initialize auto-detector
    decoder = AutoDetectBytesDecoder()

    # Decode with automatic detection
    result = decoder.decode(byte_sequence)

    return result

def demo_byte_predictor():
    """Demo with smaller dimensions to test."""
    # Initialize model configuration with adjusted dimensions
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,  # Smaller for testing
        num_layers=2,  # Fewer layers for testing
        num_key_value_heads=2,
        num_query_heads=4,
        dropout=0.1,
        max_sequence_length=1024,
    )

    # Initialize model
    model = EnhancedBytePredictor(config)
    logger.info("Model initialized")

    # Move to GPU if available
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    model = model.to(device)
    logger.info(f"Using device: {device}")

    # Create sample input data
    batch_size = 2
    seq_length = 16  # Shorter sequence for testing
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_length), device=device
    )
    logger.info(f"Created input tensor of shape: {input_ids.shape}")

    # Test forward pass
    try:
        logits = model(input_ids)
        logger.info(
            f"Forward pass successful! Output shape: {logits.shape}"
        )

        # Test loss computation
        target_ids = torch.randint(
            0,
            config.vocab_size,
            (batch_size, seq_length),
            device=device,
        )
        loss = model.compute_loss(input_ids, target_ids)
        logger.info(
            f"Loss computation successful! Loss value: {loss.item():.4f}"
        )

        # Test generation
        prompt = torch.randint(
            0,
            config.vocab_size,
            (1, 4),  # Very short prompt for testing
            device=device,
        )
        generated = model.generate(
            prompt, max_new_tokens=8, temperature=0.8, top_k=50
        )
        # generate() returns decoded data (text/bytes/...), not a tensor, so
        # report its type rather than a tensor shape.
        logger.info(
            f"Generation successful! Decoded output type: {type(generated).__name__}"
        )

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        raise


if __name__ == "__main__":
    # Set up logging
    # logger.remove()  # Remove default handler
    # logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")

    demo_byte_predictor()

@@ -0,0 +1,73 @@
import os
import tempfile

import pygame
from loguru import logger
from openai import OpenAI

class OpenAITTS:
    """
    A class to interact with the OpenAI API and play the generated audio with improved streaming capabilities.
    """

    def __init__(self, *args, **kwargs):
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), *args, **kwargs
        )
        pygame.init()

    def run(
        self, task: str, play_sound: bool = True, *args, **kwargs
    ):
        """
        Run a task with the OpenAI API and optionally play the generated audio with improved streaming.

        Args:
            task (str): The task to be executed.
            play_sound (bool): If True, play the generated audio.

        Returns:
            None
        """
        try:
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="nova",
                input=task,
                *args,
                **kwargs,
            )
            # NOTE: with the openai>=1.x client the speech endpoint returns binary
            # audio content rather than a URL (the original code expected
            # response["url"]), so the bytes are written straight to a temp file.
            logger.info("Task completed successfully.")

            if play_sound:
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=".mp3"
                ) as tmp_file:
                    tmp_file.write(response.content)
                pygame.mixer.music.load(tmp_file.name)
                pygame.mixer.music.play()
                while pygame.mixer.music.get_busy():
                    pygame.time.Clock().tick(10)
        except Exception as e:
            logger.error(f"Error during task execution: {str(e)}")

# client = OpenAITTS(api_key=os.getenv("OPENAI_API_KEY"))
# client.run("Hello world! This is a streaming test.", play_sound=True)


def text_to_speech(
    task: str, play_sound: bool = True, *args, **kwargs
):
    out = OpenAITTS().run(
        task, play_sound=play_sound, *args, **kwargs
    )
    return out


# print(text_to_speech(task="hello"))