You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

256 lines
9.4 KiB

import uvicorn
import os
import asyncio
import threading
from queue import Queue
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataset import examples
from datasets import Dataset
import numpy as np
from evaluate import load
from transformers import Trainer, TrainingArguments
app = FastAPI()
model: Optional[AutoModelForSequenceClassification] = None
tokenizer: Optional[AutoTokenizer] = None
classes = examples.classes
# Флаги состояния
training_in_progress = False
training_complete = False
# Очереди для обработки запросов
queue1 = Queue()
queue2 = Queue()
current_queue = queue1
processing_queue = queue2
# Lock для переключения очередей
queue_lock = threading.Lock()
MODEL_DIR = "./trained_model"
MODEL_NAME = "DeepPavlov/rubert-base-cased"
# Pydantic модель для запроса предсказания
class PredictionRequest(BaseModel):
text: str
def prepare_datasets(train_texts, train_labels, val_texts, val_labels, tokenizer):
"""Создаёт токенизированные датасеты для обучения и валидации."""
train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
val_dataset = Dataset.from_dict({"text": val_texts, "label": val_labels})
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)
train_dataset = train_dataset.remove_columns(["text"])
val_dataset = val_dataset.remove_columns(["text"])
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
return train_dataset, val_dataset
def train_model(train_dataset, val_dataset, classes, model_name, output_dir="./results", num_epochs=3):
"""Обучает модель и сохраняет её в указанной директории."""
global training_in_progress, training_complete, model, tokenizer
training_in_progress = True
training_complete = False
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(classes))
training_args = TrainingArguments(
output_dir=output_dir,
eval_strategy="epoch",
save_strategy="epoch",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=num_epochs,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="accuracy"
)
accuracy_metric = load("accuracy")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return accuracy_metric.compute(predictions=predictions, references=labels)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
trainer.train()
trainer.save_model(output_dir)
print(f"Модель сохранена в директории {output_dir}")
# Загрузка сохранённой модели
model = AutoModelForSequenceClassification.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name)
training_complete = True
except Exception as e:
print(f"Ошибка при обучении модели: {e}")
finally:
training_in_progress = False
def can_launch_backend():
"""Проверяет, чтобы размерности данных совпадали и модель существует."""
count_text = len(examples.train_texts)
count_lb = len(examples.train_labels)
if count_text != count_lb:
print(f"Размерности данных не совпадают: {count_text} текстов и {count_lb} меток.")
return False
if not os.path.exists(MODEL_DIR):
print(f"Директория модели {MODEL_DIR} не существует.")
return False
return True
def load_model():
"""Загружает сохранённую модель и токенизатор."""
global model, tokenizer
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Модель и токенизатор загружены.")
def prediction_worker():
while True:
request = processing_queue.get()
if request is None:
break # Завершение работы
text, response_future = request
try:
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=1).item()
predicted_class = classes[predictions]
response_future.set_result(predicted_class)
except Exception as e:
response_future.set_exception(e)
# Функция переключения очередей
def flip_buffers():
global current_queue, processing_queue
with queue_lock:
current_queue, processing_queue = processing_queue, current_queue
async def process_queue():
while True:
flip_buffers()
while not processing_queue.empty():
request = processing_queue.get()
queue1.put(request) if current_queue == queue1 else queue2.put(request)
await asyncio.sleep(0.1) # Интервал между переключениями
# Запуск воркера предсказаний
prediction_thread = threading.Thread(target=prediction_worker, daemon=True)
prediction_thread.start()
# Запуск фоновой задачи для обработки очереди
@app.on_event("startup")
async def startup_event():
global training_in_progress, training_complete, model, tokenizer
# Проверка размерностей данных
count_text = len(examples.train_texts)
count_lb = len(examples.train_labels)
print(f"Размерности данных: {count_text} текстов и {count_lb} меток.")
if count_text != count_lb:
print("Размерности данных не совпадают. Backend не может быть запущен.")
return
# Проверка наличия модели
if not os.path.exists(MODEL_DIR):
print("Модель не найдена. Начинается обучение модели.")
tokenizer_temp = AutoTokenizer.from_pretrained(MODEL_NAME)
train_dataset, val_dataset = prepare_datasets(
examples.train_texts,
examples.train_labels,
examples.val_texts,
examples.val_labels,
tokenizer_temp
)
training_thread = threading.Thread(
target=train_model,
args=(train_dataset, val_dataset, classes, MODEL_NAME, MODEL_DIR, 3),
daemon=True
)
training_thread.start()
else:
print("Модель найдена. Загружается модель.")
load_model()
# Запуск фоновой задачи для обработки очереди
asyncio.create_task(process_queue())
@app.post("/predict")
async def predict_endpoint(request: PredictionRequest):
if training_in_progress:
raise HTTPException(status_code=503, detail="Модель обучается. Пожалуйста, попробуйте позже.")
if not training_complete and not model:
raise HTTPException(status_code=503, detail="Модель не загружена. Пожалуйста, попробуйте позже.")
loop = asyncio.get_event_loop()
response_future = asyncio.Future()
# Добавление запроса в очередь
with queue_lock:
current_queue.put((request.text, response_future))
try:
predicted_class = await asyncio.wait_for(response_future, timeout=10.0)
return {"predicted_class": predicted_class}
except asyncio.TimeoutError:
raise HTTPException(status_code=504, detail="Время ожидания предсказания истекло.")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/status")
async def status():
if training_in_progress:
return {"status": "Обучение продолжается"}
elif training_complete:
return {"status": "Модель готова"}
elif os.path.exists(MODEL_DIR):
return {"status": "Модель загружена"}
else:
return {"status": "Модель отсутствует"}
# Завершение работы
@app.on_event("shutdown")
def shutdown_event():
# Остановить воркер предсказаний
processing_queue.put(None)
prediction_thread.join()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)