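# Fine-tunes a Hugging Face sequence-classification model and runs inference.
# Assumed dependencies (not pinned in the original), e.g.:
#   pip install transformers datasets evaluate torch numpy
# Note: evaluate's "accuracy" metric also pulls in scikit-learn at runtime.
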
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
from evaluate import load


def prepare_datasets(train_texts, train_labels, val_texts, val_labels, tokenizer):
    """Create tokenized datasets for training and validation."""
    train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
    val_dataset = Dataset.from_dict({"text": val_texts, "label": val_labels})

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Drop the raw text column; the model only needs token ids, masks, and labels.
    train_dataset = train_dataset.remove_columns(["text"])
    val_dataset = val_dataset.remove_columns(["text"])

    train_dataset = train_dataset.with_format("torch")
    val_dataset = val_dataset.with_format("torch")

    return train_dataset, val_dataset


def train_model(train_dataset, val_dataset, classes, model_name, output_dir="./results", num_epochs=3):
    """Train the model and save it to the given directory."""
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(classes))
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",  # named evaluation_strategy in transformers < 4.41
        save_strategy="epoch",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
    )

    accuracy_metric = load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return accuracy_metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(output_dir)
    print(f"Model saved to {output_dir}")

    return model, tokenizer


def load_model_and_tokenizer(model_dir, model_name):
    """Load a saved model from model_dir and the matching tokenizer from model_name."""
    if not os.path.exists(model_dir):
        raise ValueError(f"Directory {model_dir} does not exist.")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def predict(model, tokenizer, text, classes):
    """Predict the class of the given text."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    model.eval()
    with torch.no_grad():  # inference only: disable dropout and gradient tracking
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    return classes[prediction]
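

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal end-to-end run on a tiny in-memory dataset. The texts, labels,
# class names, and the "distilbert-base-uncased" checkpoint are placeholder
# choices for demonstration only.
if __name__ == "__main__":
    classes = ["negative", "positive"]
    train_texts = ["loved it", "terrible plot", "great acting", "boring"]
    train_labels = [1, 0, 1, 0]
    val_texts = ["awful", "fantastic"]
    val_labels = [0, 1]

    base_model = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_ds, val_ds = prepare_datasets(train_texts, train_labels, val_texts, val_labels, tokenizer)

    train_model(train_ds, val_ds, classes, base_model, num_epochs=1)

    # Reload from disk and classify a new text.
    model, tokenizer = load_model_and_tokenizer("./results", base_model)
    print(predict(model, tokenizer, "what a wonderful film", classes))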