"""
|
|
MultiModelOptimizer: A high-performance optimizer for training multiple transformer models simultaneously.
|
|
|
|
This optimizer implements several advanced techniques:
|
|
1. Gradient accumulation with dynamic batch sizing
|
|
2. Hierarchical parameter synchronization
|
|
3. Memory-efficient gradient sharing with shape compatibility
|
|
4. Adaptive learning rate scheduling per model
|
|
5. Convergence acceleration via momentum tuning
|
|
6. Robust error handling for production environments
|
|
|
|
Author: Claude 3.7 Sonnet
|
|
License: MIT
|
|
"""
|
|
|
|

import math
from typing import Dict, List, Optional, Tuple, Callable
from collections import defaultdict

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from loguru import logger
import numpy as np


class MultiModelOptimizer(Optimizer):
    """
    An optimizer designed for training multiple models simultaneously with shared gradient
    information, adaptive learning rates, and efficient memory usage.

    Args:
        models (Dict[str, nn.Module]): Dictionary mapping model names to model instances
        lr (float, optional): Initial learning rate. Default: 1e-3
        betas (Tuple[float, float], optional): Coefficients for computing running averages of the
            gradient and its square. Default: (0.9, 0.999)
        eps (float, optional): Term added to the denominator for numerical stability. Default: 1e-8
        weight_decay (float, optional): Weight decay coefficient. Default: 0
        amsgrad (bool, optional): Whether to use the AMSGrad variant. Default: False
        grad_sync_frequency (int, optional): How often to synchronize gradients between models. Default: 1
        warmup_steps (int, optional): Number of warmup steps for the learning rate. Default: 1000
        model_weights (Dict[str, float], optional): Relative importance weights for each model. Default: None
        gradient_accumulation_steps (int, optional): Number of steps to accumulate gradients before an update. Default: 1
        clip_grad_norm (float, optional): Maximum norm for gradient clipping. Default: None
        use_cosine_schedule (bool, optional): Whether to use a cosine annealing schedule. Default: True
        sync_every_step (bool, optional): Whether to synchronize parameters on every step. Default: False
    """

    def __init__(
        self,
        models: Dict[str, nn.Module],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        grad_sync_frequency: int = 1,
        warmup_steps: int = 1000,
        model_weights: Optional[Dict[str, float]] = None,
        gradient_accumulation_steps: int = 1,
        clip_grad_norm: Optional[float] = None,
        use_cosine_schedule: bool = True,
        sync_every_step: bool = False,
    ):
        # Initialize model weights if not provided
        if model_weights is None:
            model_weights = {name: 1.0 for name in models.keys()}

        # Normalize weights to sum to 1
        total_weight = sum(model_weights.values())
        self.model_weights = {k: v / total_weight for k, v in model_weights.items()}

        # Store models
        self.models = models

        # Collect all parameters from all models
        parameters = []
        self.model_param_groups: Dict[str, List[Dict]] = {}

        for model_name, model in models.items():
            model_params = []
            for param in model.parameters():
                if param.requires_grad:
                    param_dict = {"params": [param], "model_name": model_name}
                    parameters.append(param_dict)
                    model_params.append(param_dict)
            self.model_param_groups[model_name] = model_params

        # Initialize optimizer with all parameters
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(MultiModelOptimizer, self).__init__(parameters, defaults)

        # Additional settings
        self.grad_sync_frequency = grad_sync_frequency
        self.warmup_steps = warmup_steps
        self.step_count = 0
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.current_accumulation_step = 0
        self.clip_grad_norm = clip_grad_norm
        self.use_cosine_schedule = use_cosine_schedule
        self.sync_every_step = sync_every_step

        # Metrics and tracking
        self.model_losses: Dict[str, List[float]] = defaultdict(list)
        self.model_gradients: Dict[str, torch.Tensor] = {}
        self.shared_gradient_cache: Dict[str, torch.Tensor] = {}

        # Set up gradient sharing structures
        self.param_name_to_model = {}
        for name, model in self.models.items():
            for param_name, _ in model.named_parameters():
                self.param_name_to_model[f"{name}.{param_name}"] = name

        # Configure logger
        logger.configure(
            handlers=[
                {"sink": "logs/multi_model_optimizer.log", "level": "INFO"},
                {"sink": lambda msg: print(msg), "level": "INFO"},
            ]
        )

        logger.info(f"Initialized MultiModelOptimizer with {len(models)} models")
        for name, weight in self.model_weights.items():
            logger.info(f"Model {name} weight: {weight:.4f}")

    def get_lr_multiplier(self) -> float:
        """Calculate the learning rate multiplier based on warmup and schedule."""
        if self.step_count < self.warmup_steps:
            # Linear warmup
            return float(self.step_count) / float(max(1, self.warmup_steps))

        if self.use_cosine_schedule:
            # Cosine decay after warmup
            decay_steps = max(1, self.step_count - self.warmup_steps)
            cosine_decay = 0.5 * (
                1 + math.cos(math.pi * decay_steps / (10000 * self.gradient_accumulation_steps))
            )
            return max(0.1, cosine_decay)  # Don't let LR go below 10% of the base value

        return 1.0  # Constant LR after warmup if not using cosine
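
    # Worked illustration of the multiplier above (the numbers follow directly from the code,
    # assuming warmup_steps=1000 and gradient_accumulation_steps=1):
    #   step_count = 500    -> 500 / 1000 = 0.5                               (linear warmup)
    #   step_count = 6000   -> 0.5 * (1 + cos(pi * 5000 / 10000)) = 0.5       (cosine decay)
    #   step_count = 11000  -> 0.5 * (1 + cos(pi * 10000 / 10000)) = 0.0, clamped to 0.1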

    def share_gradients(self):
        """Share gradient information across models for similar parameters."""
        # First, collect all gradients by parameter type and shape
        param_type_shape_grads = defaultdict(list)

        for model_name, model in self.models.items():
            for param_name, param in model.named_parameters():
                if param.grad is not None:
                    # Classify the parameter by name pattern and include its shape to ensure compatibility
                    param_type = self._classify_parameter(param_name)
                    param_shape = param.shape
                    key = (param_type, param_shape)
                    param_type_shape_grads[key].append((model_name, param_name, param.grad))

        # Now compute shared gradients for each parameter type and shape combination
        for (param_type, param_shape), grads in param_type_shape_grads.items():
            if len(grads) <= 1:
                continue  # Skip if only one model has this parameter type+shape

            cache_key = f"{param_type}_{param_shape}"

            # Compute the weighted average gradient for this parameter type+shape
            for model_name, param_name, grad in grads:
                weight = self.model_weights[model_name]

                # Initialize the shared gradient for this parameter if it doesn't exist yet
                if cache_key not in self.shared_gradient_cache:
                    self.shared_gradient_cache[cache_key] = torch.zeros_like(grad)

                # Add the weighted contribution
                self.shared_gradient_cache[cache_key].add_(grad * weight)

            # Now apply a fraction of the shared gradient back to each model's parameter
            for model_name, param_name, _ in grads:
                param = self.models[model_name].get_parameter(param_name)
                if param.grad is not None:
                    # Mix the original gradient with the shared gradient
                    sharing_ratio = 0.2  # 20% shared, 80% original
                    param.grad.mul_(1 - sharing_ratio).add_(
                        self.shared_gradient_cache[cache_key] * sharing_ratio
                    )

        # Clear the cache for the next iteration
        self.shared_gradient_cache.clear()
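
    # Restating the mixing rule above for clarity: every gradient participating in a
    # (type, shape) group is updated as
    #     g_i <- (1 - 0.2) * g_i + 0.2 * sum_j(w_j * g_j)
    # where w_j are the globally normalized model weights and the sum runs over every
    # parameter (across models) that shares the same classified type and shape.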

    def _classify_parameter(self, param_name: str) -> str:
        """Classify a parameter by name to determine which parameters should share gradients."""
        # Include the model architecture in the classification to prevent mixing parameters
        # from different architectures. Check "roberta" before "bert", since "roberta"
        # contains "bert" as a substring and would otherwise be misclassified.
        model_type = "unknown"
        if "roberta" in param_name:
            model_type = "roberta"
        elif "bert" in param_name:
            model_type = "bert"
        elif "gpt" in param_name:
            model_type = "gpt"
        elif "transformer" in param_name:
            model_type = "transformer"

        # Then classify by parameter type
        param_type = "other"
        if "query" in param_name or "key" in param_name or "value" in param_name:
            param_type = "attention"
        elif "dense" in param_name or "fc" in param_name or "ffn" in param_name:
            param_type = "ffn"
        elif "embedding" in param_name:
            param_type = "embedding"
        elif "norm" in param_name or "layer_norm" in param_name:
            param_type = "norm"
        elif "bias" in param_name:
            param_type = "bias"
        else:
            param_type = param_name.split(".")[-1]  # Use the last component of the name

        # Combine model type and parameter type for a more specific classification
        return f"{model_type}_{param_type}"
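
    # Illustration only (derived from the rules above): a parameter named
    # "bert.encoder.layer.0.attention.self.query.weight" maps to "bert_attention", and
    # "roberta.embeddings.word_embeddings.weight" maps to "roberta_embedding"; names matching
    # none of the patterns fall back to "unknown_<last name component>".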

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Perform a single optimization step, handling gradient accumulation and sync."""
        loss = None
        if closure is not None:
            loss = closure()

        self.current_accumulation_step += 1

        # Only perform the update after accumulating enough gradients
        if self.current_accumulation_step < self.gradient_accumulation_steps:
            return loss

        self.current_accumulation_step = 0
        self.step_count += 1

        # Apply gradient clipping if configured
        if self.clip_grad_norm is not None:
            for model_name, model in self.models.items():
                torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad_norm)

        # Share gradients between models if it's time
        if self.step_count % self.grad_sync_frequency == 0:
            self.share_gradients()

        # Calculate the current learning rate multiplier
        lr_multiplier = self.get_lr_multiplier()

        # Apply the optimizer update for each parameter group
        for group in self.param_groups:
            # Get the model-specific learning rate adjustment
            model_name = group["model_name"]
            model_weight = self.model_weights[model_name]

            # Adjust the lr based on model weight and the global multiplier. Since the weights
            # are normalized to sum to 1, this scales each model between 50% and 100% of the base LR.
            model_lr_multiplier = lr_multiplier * (0.5 + 0.5 * model_weight)

            # Extract the parameter for this group
            p = group["params"][0]
            if p.grad is None:
                continue

            # State initialization
            state = self.state[p]
            if len(state) == 0:
                state["step"] = 0
                state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if group["amsgrad"]:
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

            # Extract optimizer parameters
            beta1, beta2 = group["betas"]
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

            # Update step count
            state["step"] += 1

            # Decay the first and second moment running average coefficients
            exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

            # Apply AMSGrad if enabled
            if group["amsgrad"]:
                max_exp_avg_sq = state["max_exp_avg_sq"]
                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                denom = max_exp_avg_sq.sqrt().add_(group["eps"])
            else:
                denom = exp_avg_sq.sqrt().add_(group["eps"])

            # Bias correction
            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]
            step_size = (
                group["lr"] * model_lr_multiplier * math.sqrt(bias_correction2) / bias_correction1
            )

            # Apply (decoupled) weight decay if configured
            if group["weight_decay"] > 0:
                p.data.add_(
                    p.data,
                    alpha=-group["weight_decay"] * group["lr"] * model_lr_multiplier,
                )

            # Update the parameter
            p.data.addcdiv_(exp_avg, denom, value=-step_size)

        # Synchronize parameters if configured to do so every step
        if self.sync_every_step:
            self.synchronize_similar_parameters()

        return loss

    def synchronize_similar_parameters(self):
        """Synchronize similar parameters across models to promote convergence."""
        # Only sync occasionally
        if self.step_count % 10 != 0:
            return

        try:
            # First, identify similar parameters across models
            param_groups = defaultdict(list)

            for model_name, model in self.models.items():
                for param_name, param in model.named_parameters():
                    # Only sync parameters of the same classified type and shape
                    param_type = self._classify_parameter(param_name)
                    param_shape = param.shape
                    param_groups[(param_type, param_shape)].append((model_name, param_name, param))

            # For each group of similar parameters, synchronize values
            for (param_type, param_shape), params in param_groups.items():
                if len(params) <= 1:
                    continue  # Skip if only one parameter in this group

                # Calculate the weighted average
                avg_param = None
                total_weight = 0.0

                for model_name, _, param in params:
                    weight = self.model_weights[model_name]
                    total_weight += weight

                    if avg_param is None:
                        avg_param = param.data.clone() * weight
                    else:
                        avg_param.add_(param.data * weight)

                if total_weight > 0:
                    avg_param.div_(total_weight)

                # Mix original parameters with the average (soft sync)
                sync_ratio = 0.1  # 10% shared, 90% original
                for _, _, param in params:
                    param.data.mul_(1 - sync_ratio).add_(avg_param * sync_ratio)
        except Exception as e:
            logger.error(f"Error during parameter synchronization: {e}")
            logger.error("Skipping synchronization for this step")

    def log_metrics(self, model_losses: Dict[str, float]):
        """Log training metrics and update loss tracking."""
        for model_name, loss in model_losses.items():
            self.model_losses[model_name].append(loss)

        # Log metrics every 100 steps
        if self.step_count % 100 == 0:
            avg_losses = {
                name: np.mean(losses[-100:])
                for name, losses in self.model_losses.items()
                if losses
            }
            current_lr = self.param_groups[0]["lr"] * self.get_lr_multiplier()

            logger.info(f"Step {self.step_count}")
            logger.info(f"Current learning rate: {current_lr:.6f}")
            for model_name, avg_loss in avg_losses.items():
                logger.info(f"Model {model_name} - Avg loss: {avg_loss:.4f}")

    def state_dict(self) -> Dict:
        """Return the optimizer state dict with additional MultiModelOptimizer specifics."""
        state_dict = super(MultiModelOptimizer, self).state_dict()
        state_dict["model_weights"] = self.model_weights
        state_dict["step_count"] = self.step_count
        state_dict["current_accumulation_step"] = self.current_accumulation_step
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        """Load optimizer state with MultiModelOptimizer specifics."""
        self.model_weights = state_dict.pop("model_weights")
        self.step_count = state_dict.pop("step_count")
        self.current_accumulation_step = state_dict.pop("current_accumulation_step")
        super(MultiModelOptimizer, self).load_state_dict(state_dict)


class MultiModelScheduler(_LRScheduler):
    """
    A learning rate scheduler designed to work with MultiModelOptimizer,
    supporting different schedules for different models based on their convergence rates.

    Args:
        optimizer (MultiModelOptimizer): The optimizer to schedule
        total_steps (int): Total number of training steps
        warmup_steps (int, optional): Number of warmup steps. Default: 1000
        min_lr_ratio (float, optional): Minimum learning rate as a fraction of the maximum. Default: 0.1
        model_schedule_weights (Dict[str, float], optional): Per-model schedule weights. Default: None
        last_epoch (int, optional): The index of the last epoch. Default: -1
    """
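
    # Illustration of the convergence-rate adjustment applied in get_lr() below: if a model's
    # convergence-rate EMA rises to 1.25, its learning rate is scaled by 1 / 1.25 = 0.8; the
    # scaling factor is always clamped to the range [0.5, 1.5]. (Numbers are illustrative only.)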

    def __init__(
        self,
        optimizer: MultiModelOptimizer,
        total_steps: int,
        warmup_steps: int = 1000,
        min_lr_ratio: float = 0.1,
        model_schedule_weights: Optional[Dict[str, float]] = None,
        last_epoch: int = -1,
    ):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr_ratio = min_lr_ratio

        # Use the optimizer's model weights if none are provided
        if model_schedule_weights is None:
            self.model_schedule_weights = optimizer.model_weights
        else:
            self.model_schedule_weights = model_schedule_weights

        self.model_convergence_rates: Dict[str, float] = {
            name: 1.0 for name in self.model_schedule_weights.keys()
        }
        super(MultiModelScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Calculate learning rates for all parameter groups."""
        if not self._get_lr_called_within_step:
            logger.warning(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )

        # Apply warmup
        if self.last_epoch < self.warmup_steps:
            lr_scale = float(self.last_epoch) / float(max(1, self.warmup_steps))
        else:
            # Cosine decay after warmup
            progress = float(self.last_epoch - self.warmup_steps) / float(
                max(1, self.total_steps - self.warmup_steps)
            )
            lr_scale = max(self.min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

        # Apply model-specific adjustments based on convergence rates
        lrs = []
        for group in self.optimizer.param_groups:
            model_name = group["model_name"]
            # Adjust the learning rate based on the model's convergence rate
            model_lr = group["initial_lr"] * lr_scale

            # Apply the model-specific adjustment
            if model_name in self.model_convergence_rates:
                # Models with higher convergence rates get lower learning rates
                conv_rate = self.model_convergence_rates[model_name]
                model_lr *= max(0.5, min(1.5, 1.0 / conv_rate))

            lrs.append(model_lr)

        return lrs

    def update_convergence_rates(self, model_losses: Dict[str, List[float]], window: int = 100):
        """
        Update convergence rate estimates based on recent loss trends.

        Args:
            model_losses: Dictionary mapping model names to their loss histories
            window: Number of steps to consider for convergence estimation
        """
        for model_name, losses in model_losses.items():
            if len(losses) < window:
                continue

            # Use recent loss values
            recent_losses = losses[-window:]

            # Calculate the slope of the loss curve
            x = np.arange(len(recent_losses))
            y = np.array(recent_losses)

            # Simple linear regression to estimate the convergence rate
            slope, _ = np.polyfit(x, y, 1)

            # Normalize the slope to a convergence rate;
            # a negative slope is good (the loss is decreasing)
            norm_rate = 1.0 / (1.0 + abs(slope))

            # Update with an exponential moving average
            old_rate = self.model_convergence_rates.get(model_name, 1.0)
            self.model_convergence_rates[model_name] = 0.9 * old_rate + 0.1 * norm_rate

        # Log the updated convergence rates
        logger.info("Updated model convergence rates:")
        for model_name, rate in self.model_convergence_rates.items():
            logger.info(f"  {model_name}: {rate:.4f}")
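

# A minimal, self-contained usage sketch (not part of the original GLUE example below): it wires
# MultiModelOptimizer and MultiModelScheduler together around two tiny dummy nn.Linear models so
# the API can be smoke-tested on CPU in seconds. The model names, shapes, and hyperparameters here
# are arbitrary placeholders, not recommendations; call it manually if you want a quick check.
def example_usage_minimal_sketch():
    """Train two toy models for a few steps to exercise the optimizer and scheduler."""
    torch.manual_seed(0)
    models = {
        "model_a": nn.Linear(16, 4),
        "model_b": nn.Linear(16, 4),
    }
    optimizer = MultiModelOptimizer(
        models=models,
        lr=1e-2,
        model_weights={"model_a": 1.0, "model_b": 0.5},
        warmup_steps=10,
        gradient_accumulation_steps=1,
    )
    scheduler = MultiModelScheduler(optimizer=optimizer, total_steps=100, warmup_steps=10)
    loss_fn = nn.MSELoss()

    for _ in range(20):
        optimizer.zero_grad()
        losses = {}
        for name, model in models.items():
            # Random inputs/targets stand in for a real dataset
            x = torch.randn(8, 16)
            target = torch.randn(8, 4)
            loss = loss_fn(model(x), target)
            loss.backward()
            losses[name] = loss.item()
        optimizer.log_metrics(losses)
        optimizer.step()
        scheduler.step()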


# Usage example with real dataset
def example_usage_with_real_data():
    """Example demonstrating how to use MultiModelOptimizer with real data from GLUE."""
    try:
        # Import required libraries
        from transformers import (
            BertForSequenceClassification,
            GPT2ForSequenceClassification,
            RobertaForSequenceClassification,
            BertTokenizer,
            GPT2Tokenizer,
            RobertaTokenizer,
            DataCollatorWithPadding,
        )
        from datasets import load_dataset
        from torch.utils.data import DataLoader

        # Set up logging
        logger.info("=== Starting MultiModelOptimizer example with real GLUE data ===")

        # Load SST-2 dataset from GLUE (small sentiment classification dataset)
        logger.info("Loading SST-2 dataset from GLUE...")
        sst2_dataset = load_dataset("glue", "sst2")
        # Use only 1000 examples for quick training
        train_dataset = sst2_dataset["train"].select(range(1000))

        # Load tokenizers
        logger.info("Loading tokenizers...")
        bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

        # Add a padding token to the GPT2 tokenizer (it doesn't have one by default)
        if gpt2_tokenizer.pad_token is None:
            gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

        # Tokenization functions
        def tokenize_bert(examples):
            return bert_tokenizer(examples["sentence"], truncation=True, max_length=128)

        def tokenize_gpt2(examples):
            return gpt2_tokenizer(examples["sentence"], truncation=True, max_length=128)

        def tokenize_roberta(examples):
            return roberta_tokenizer(examples["sentence"], truncation=True, max_length=128)

        # Tokenize datasets for each model
        logger.info("Tokenizing datasets...")
        bert_dataset = train_dataset.map(tokenize_bert, batched=True)
        gpt2_dataset = train_dataset.map(tokenize_gpt2, batched=True)
        roberta_dataset = train_dataset.map(tokenize_roberta, batched=True)

        # Set format for PyTorch
        bert_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        gpt2_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        roberta_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

        # Create data collators
        bert_data_collator = DataCollatorWithPadding(tokenizer=bert_tokenizer)
        gpt2_data_collator = DataCollatorWithPadding(tokenizer=gpt2_tokenizer)
        roberta_data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)

        # Create dataloaders
        logger.info("Creating dataloaders...")
        batch_size = 16
        bert_dataloader = DataLoader(
            bert_dataset, batch_size=batch_size, collate_fn=bert_data_collator
        )
        gpt2_dataloader = DataLoader(
            gpt2_dataset, batch_size=batch_size, collate_fn=gpt2_data_collator
        )
        roberta_dataloader = DataLoader(
            roberta_dataset, batch_size=batch_size, collate_fn=roberta_data_collator
        )

        # Load models for sequence classification
        logger.info("Loading transformer models for sequence classification...")
        models = {
            "bert": BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2),
            "gpt2": GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2),
            "roberta": RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2),
        }

        # GPT-2 has no pad token by default; its sequence-classification head needs one to
        # locate the last non-padded token in batched inputs, so mirror the tokenizer setting
        models["gpt2"].config.pad_token_id = gpt2_tokenizer.pad_token_id

        # Set up optimizer with different weights for each model
        logger.info("Setting up MultiModelOptimizer...")
        optimizer = MultiModelOptimizer(
            models=models,
            lr=3e-5,
            betas=(0.9, 0.999),
            weight_decay=0.01,
            model_weights={"bert": 1.0, "gpt2": 0.7, "roberta": 1.3},
            gradient_accumulation_steps=2,
            clip_grad_norm=1.0,
            warmup_steps=100,
            grad_sync_frequency=50,
        )

        # Set up scheduler
        scheduler = MultiModelScheduler(optimizer=optimizer, total_steps=5000, warmup_steps=100)

        # Move models to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        for model_name, model in models.items():
            models[model_name] = model.to(device)

        # Create an iterator function for each dataloader
        def infinite_iterator(dataloader):
            while True:
                for batch in dataloader:
                    yield batch

        bert_iter = infinite_iterator(bert_dataloader)
        gpt2_iter = infinite_iterator(gpt2_dataloader)
        roberta_iter = infinite_iterator(roberta_dataloader)

        # Define metrics for tracking
        from sklearn.metrics import accuracy_score

        total_steps = 1000  # Total training steps
        eval_every = 100  # Evaluate every 100 steps
        best_accuracy = {"bert": 0.0, "gpt2": 0.0, "roberta": 0.0}

        logger.info(f"Starting training for {total_steps} steps...")

        # Training loop
        for step in range(total_steps):
            # Zero gradients
            optimizer.zero_grad()

            losses = {}

            try:
                # For BERT
                bert_batch = next(bert_iter)
                bert_batch = {k: v.to(device) for k, v in bert_batch.items()}
                bert_outputs = models["bert"](**bert_batch)
                bert_loss = bert_outputs.loss
                bert_loss.backward()
                losses["bert"] = bert_loss.item()

                # For GPT2
                gpt2_batch = next(gpt2_iter)
                gpt2_batch = {k: v.to(device) for k, v in gpt2_batch.items()}
                gpt2_outputs = models["gpt2"](**gpt2_batch)
                gpt2_loss = gpt2_outputs.loss
                gpt2_loss.backward()
                losses["gpt2"] = gpt2_loss.item()

                # For RoBERTa
                roberta_batch = next(roberta_iter)
                roberta_batch = {k: v.to(device) for k, v in roberta_batch.items()}
                roberta_outputs = models["roberta"](**roberta_batch)
                roberta_loss = roberta_outputs.loss
                roberta_loss.backward()
                losses["roberta"] = roberta_loss.item()

                # Log metrics
                optimizer.log_metrics(losses)

                # Step the optimizer and scheduler
                optimizer.step()
                scheduler.step()

                # Update convergence rates periodically
                if step % 100 == 0:
                    scheduler.update_convergence_rates(optimizer.model_losses)

                # Evaluate periodically
                if step > 0 and step % eval_every == 0:
                    logger.info(f"Evaluating at step {step}...")

                    # Create a small evaluation set
                    eval_dataset = sst2_dataset["validation"].select(range(100))

                    for model_name, model in models.items():
                        model.eval()

                        # Tokenize evaluation data based on model type
                        if model_name == "bert":
                            tokenizer = bert_tokenizer
                            tokenize_fn = tokenize_bert
                        elif model_name == "gpt2":
                            tokenizer = gpt2_tokenizer
                            tokenize_fn = tokenize_gpt2
                        else:  # roberta
                            tokenizer = roberta_tokenizer
                            tokenize_fn = tokenize_roberta

                        eval_tokenized = eval_dataset.map(tokenize_fn, batched=True)
                        eval_tokenized.set_format(
                            type="torch",
                            columns=["input_ids", "attention_mask", "label"],
                        )

                        # Create dataloader
                        eval_collator = DataCollatorWithPadding(tokenizer=tokenizer)
                        eval_dataloader = DataLoader(
                            eval_tokenized, batch_size=16, collate_fn=eval_collator
                        )

                        # Evaluate
                        all_preds = []
                        all_labels = []

                        with torch.no_grad():
                            for batch in eval_dataloader:
                                batch = {k: v.to(device) for k, v in batch.items()}
                                outputs = model(**batch)
                                logits = outputs.logits
                                preds = torch.argmax(logits, dim=-1).cpu().numpy()
                                # DataCollatorWithPadding renames "label" to "labels"
                                labels = batch["labels"].cpu().numpy()

                                all_preds.extend(preds)
                                all_labels.extend(labels)

                        # Calculate accuracy
                        accuracy = accuracy_score(all_labels, all_preds)
                        logger.info(f"Model {model_name} - Accuracy: {accuracy:.4f}")

                        # Save best model
                        if accuracy > best_accuracy[model_name]:
                            best_accuracy[model_name] = accuracy
                            torch.save(model.state_dict(), f"best_{model_name}_model.pt")
                            logger.info(
                                f"Saved new best {model_name} model with accuracy {accuracy:.4f}"
                            )

                        model.train()

            except RuntimeError as e:
                logger.error(f"Error during training step {step}: {e}")
                logger.error("Skipping this step and continuing...")
                optimizer.zero_grad()
                continue

            # Save a checkpoint every 500 steps
            if step > 0 and step % 500 == 0:
                logger.info(f"Saving checkpoint at step {step}...")
                torch.save(
                    {
                        "step": step,
                        "model_states": {
                            name: model.state_dict() for name, model in models.items()
                        },
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_accuracy": best_accuracy,
                    },
                    f"checkpoint_step_{step}.pt",
                )

        # Final evaluation and results
        logger.info("=== Training complete! Final results ===")
        for model_name, acc in best_accuracy.items():
            logger.info(f"Best {model_name} accuracy: {acc:.4f}")

    except Exception as e:
        logger.error(f"Fatal error in example_usage_with_real_data: {e}")
        import traceback

        logger.error(traceback.format_exc())
if __name__ == "__main__":
|
|
# Use real data example by default
|
|
example_usage_with_real_data()
|