Merge branch 'main' of https://git.djft.ru/darius-atlas/road-trffic-alert-bert-classificator
commit
89eb29b47b
@ -0,0 +1 @@
|
|||||||
|
.idea/
|
@ -1,62 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="AutoImportSettings">
|
|
||||||
<option name="autoReloadType" value="SELECTIVE" />
|
|
||||||
</component>
|
|
||||||
<component name="ChangeListManager">
|
|
||||||
<list default="true" id="7698c5e9-d70a-4da7-9cfb-9ca1fb64acf6" name="Changes" comment="" />
|
|
||||||
<option name="SHOW_DIALOG" value="false" />
|
|
||||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
||||||
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
|
||||||
<option name="LAST_RESOLUTION" value="IGNORE" />
|
|
||||||
</component>
|
|
||||||
<component name="FileTemplateManagerImpl">
|
|
||||||
<option name="RECENT_TEMPLATES">
|
|
||||||
<list>
|
|
||||||
<option value="Python Script" />
|
|
||||||
</list>
|
|
||||||
</option>
|
|
||||||
</component>
|
|
||||||
<component name="Git.Settings">
|
|
||||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
|
||||||
</component>
|
|
||||||
<component name="ProjectId" id="2pu9mnzciJDWejKjKpHVRTTGJax" />
|
|
||||||
<component name="ProjectViewState">
|
|
||||||
<option name="hideEmptyMiddlePackages" value="true" />
|
|
||||||
<option name="showLibraryContents" value="true" />
|
|
||||||
</component>
|
|
||||||
<component name="PropertiesComponent"><![CDATA[{
|
|
||||||
"keyToString": {
|
|
||||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
|
||||||
"WebServerToolWindowFactoryState": "false",
|
|
||||||
"node.js.detected.package.eslint": "true",
|
|
||||||
"node.js.detected.package.tslint": "true",
|
|
||||||
"node.js.selected.package.eslint": "(autodetect)",
|
|
||||||
"node.js.selected.package.tslint": "(autodetect)",
|
|
||||||
"vue.rearranger.settings.migration": "true"
|
|
||||||
}
|
|
||||||
}]]></component>
|
|
||||||
<component name="RecentsManager">
|
|
||||||
<key name="MoveFile.RECENT_KEYS">
|
|
||||||
<recent name="D:\projects\hack-rs-2024-alert-classificator\dataset" />
|
|
||||||
</key>
|
|
||||||
</component>
|
|
||||||
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
|
||||||
<component name="TaskManager">
|
|
||||||
<task active="true" id="Default" summary="Default task">
|
|
||||||
<changelist id="7698c5e9-d70a-4da7-9cfb-9ca1fb64acf6" name="Changes" comment="" />
|
|
||||||
<created>1733605254672</created>
|
|
||||||
<option name="number" value="Default" />
|
|
||||||
<option name="presentableId" value="Default" />
|
|
||||||
<updated>1733605254672</updated>
|
|
||||||
<workItem from="1733605255732" duration="6923000" />
|
|
||||||
</task>
|
|
||||||
<servers />
|
|
||||||
</component>
|
|
||||||
<component name="TypeScriptGeneratedFilesManager">
|
|
||||||
<option name="version" value="3" />
|
|
||||||
</component>
|
|
||||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
|
||||||
<SUITE FILE_PATH="coverage/hack_rs_2024_alert_classificator$main.coverage" NAME="main Coverage Results" MODIFIED="1733610491992" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -0,0 +1,27 @@
|
|||||||
|
import os
|
||||||
|
from dataset import examples
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
def can_launch_backend():
|
||||||
|
# Проверяем размеры train_texts и train_labels
|
||||||
|
count_text = len(examples.train_texts)
|
||||||
|
count_lb = len(examples.train_labels)
|
||||||
|
|
||||||
|
if count_text != count_lb:
|
||||||
|
print(f"Размерности данных не совпадают: {count_text} текстов и {count_lb} меток.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"Размерности совпадают: {count_text} текстов и {count_lb} меток.")
|
||||||
|
|
||||||
|
# Проверяем существование моделей и токенизаторов
|
||||||
|
model_name = "DeepPavlov/rubert-base-cased"
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при загрузке модели или токенизатора: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("Модель и токенизатор успешно загружены.")
|
||||||
|
return True
|
||||||
|
|
@ -1,9 +1,255 @@
|
|||||||
|
import uvicorn
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from queue import Queue
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
from dataset import examples
|
from dataset import examples
|
||||||
|
from datasets import Dataset
|
||||||
|
import numpy as np
|
||||||
|
from evaluate import load
|
||||||
|
from transformers import Trainer, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
model: Optional[AutoModelForSequenceClassification] = None
|
||||||
|
tokenizer: Optional[AutoTokenizer] = None
|
||||||
|
classes = examples.classes
|
||||||
|
|
||||||
|
# Флаги состояния
|
||||||
|
training_in_progress = False
|
||||||
|
training_complete = False
|
||||||
|
|
||||||
|
# Очереди для обработки запросов
|
||||||
|
queue1 = Queue()
|
||||||
|
queue2 = Queue()
|
||||||
|
current_queue = queue1
|
||||||
|
processing_queue = queue2
|
||||||
|
|
||||||
|
# Lock для переключения очередей
|
||||||
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
MODEL_DIR = "./trained_model"
|
||||||
|
MODEL_NAME = "DeepPavlov/rubert-base-cased"
|
||||||
|
|
||||||
|
# Pydantic модель для запроса предсказания
|
||||||
|
class PredictionRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_datasets(train_texts, train_labels, val_texts, val_labels, tokenizer):
|
||||||
|
"""Создаёт токенизированные датасеты для обучения и валидации."""
|
||||||
|
train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
|
||||||
|
val_dataset = Dataset.from_dict({"text": val_texts, "label": val_labels})
|
||||||
|
|
||||||
|
def tokenize_function(examples):
|
||||||
|
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(tokenize_function, batched=True)
|
||||||
|
val_dataset = val_dataset.map(tokenize_function, batched=True)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.remove_columns(["text"])
|
||||||
|
val_dataset = val_dataset.remove_columns(["text"])
|
||||||
|
|
||||||
|
train_dataset = train_dataset.with_format("torch")
|
||||||
|
val_dataset = val_dataset.with_format("torch")
|
||||||
|
|
||||||
|
return train_dataset, val_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(train_dataset, val_dataset, classes, model_name, output_dir="./results", num_epochs=3):
|
||||||
|
"""Обучает модель и сохраняет её в указанной директории."""
|
||||||
|
global training_in_progress, training_complete, model, tokenizer
|
||||||
|
|
||||||
|
training_in_progress = True
|
||||||
|
training_complete = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(classes))
|
||||||
|
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
per_device_train_batch_size=8,
|
||||||
|
per_device_eval_batch_size=8,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
weight_decay=0.01,
|
||||||
|
logging_dir="./logs",
|
||||||
|
logging_steps=10,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="accuracy"
|
||||||
|
)
|
||||||
|
|
||||||
|
accuracy_metric = load("accuracy")
|
||||||
|
|
||||||
|
def compute_metrics(eval_pred):
|
||||||
|
logits, labels = eval_pred
|
||||||
|
predictions = np.argmax(logits, axis=-1)
|
||||||
|
return accuracy_metric.compute(predictions=predictions, references=labels)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
compute_metrics=compute_metrics
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
print(f"Модель сохранена в директории {output_dir}")
|
||||||
|
|
||||||
|
# Загрузка сохранённой модели
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(output_dir)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
training_complete = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при обучении модели: {e}")
|
||||||
|
finally:
|
||||||
|
training_in_progress = False
|
||||||
|
|
||||||
|
def can_launch_backend():
|
||||||
|
"""Проверяет, чтобы размерности данных совпадали и модель существует."""
|
||||||
count_text = len(examples.train_texts)
|
count_text = len(examples.train_texts)
|
||||||
count_lb = len(examples.train_labels)
|
count_lb = len(examples.train_labels)
|
||||||
|
|
||||||
print(f"{count_text} / {count_lb}")
|
if count_text != count_lb:
|
||||||
|
print(f"Размерности данных не совпадают: {count_text} текстов и {count_lb} меток.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not os.path.exists(MODEL_DIR):
|
||||||
|
print(f"Директория модели {MODEL_DIR} не существует.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""Загружает сохранённую модель и токенизатор."""
|
||||||
|
global model, tokenizer
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
|
print("Модель и токенизатор загружены.")
|
||||||
|
|
||||||
|
def prediction_worker():
|
||||||
|
while True:
|
||||||
|
request = processing_queue.get()
|
||||||
|
if request is None:
|
||||||
|
break # Завершение работы
|
||||||
|
text, response_future = request
|
||||||
|
try:
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
predictions = torch.argmax(outputs.logits, dim=1).item()
|
||||||
|
predicted_class = classes[predictions]
|
||||||
|
response_future.set_result(predicted_class)
|
||||||
|
except Exception as e:
|
||||||
|
response_future.set_exception(e)
|
||||||
|
|
||||||
|
# Функция переключения очередей
|
||||||
|
def flip_buffers():
|
||||||
|
global current_queue, processing_queue
|
||||||
|
with queue_lock:
|
||||||
|
current_queue, processing_queue = processing_queue, current_queue
|
||||||
|
|
||||||
|
async def process_queue():
|
||||||
|
while True:
|
||||||
|
flip_buffers()
|
||||||
|
while not processing_queue.empty():
|
||||||
|
request = processing_queue.get()
|
||||||
|
queue1.put(request) if current_queue == queue1 else queue2.put(request)
|
||||||
|
await asyncio.sleep(0.1) # Интервал между переключениями
|
||||||
|
|
||||||
|
# Запуск воркера предсказаний
|
||||||
|
prediction_thread = threading.Thread(target=prediction_worker, daemon=True)
|
||||||
|
prediction_thread.start()
|
||||||
|
|
||||||
|
# Запуск фоновой задачи для обработки очереди
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
global training_in_progress, training_complete, model, tokenizer
|
||||||
|
|
||||||
|
# Проверка размерностей данных
|
||||||
|
count_text = len(examples.train_texts)
|
||||||
|
count_lb = len(examples.train_labels)
|
||||||
|
print(f"Размерности данных: {count_text} текстов и {count_lb} меток.")
|
||||||
|
|
||||||
|
if count_text != count_lb:
|
||||||
|
print("Размерности данных не совпадают. Backend не может быть запущен.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Проверка наличия модели
|
||||||
|
if not os.path.exists(MODEL_DIR):
|
||||||
|
print("Модель не найдена. Начинается обучение модели.")
|
||||||
|
tokenizer_temp = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
|
train_dataset, val_dataset = prepare_datasets(
|
||||||
|
examples.train_texts,
|
||||||
|
examples.train_labels,
|
||||||
|
examples.val_texts,
|
||||||
|
examples.val_labels,
|
||||||
|
tokenizer_temp
|
||||||
|
)
|
||||||
|
|
||||||
|
training_thread = threading.Thread(
|
||||||
|
target=train_model,
|
||||||
|
args=(train_dataset, val_dataset, classes, MODEL_NAME, MODEL_DIR, 3),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
training_thread.start()
|
||||||
|
else:
|
||||||
|
print("Модель найдена. Загружается модель.")
|
||||||
|
load_model()
|
||||||
|
|
||||||
|
# Запуск фоновой задачи для обработки очереди
|
||||||
|
asyncio.create_task(process_queue())
|
||||||
|
|
||||||
|
@app.post("/predict")
|
||||||
|
async def predict_endpoint(request: PredictionRequest):
|
||||||
|
if training_in_progress:
|
||||||
|
raise HTTPException(status_code=503, detail="Модель обучается. Пожалуйста, попробуйте позже.")
|
||||||
|
if not training_complete and not model:
|
||||||
|
raise HTTPException(status_code=503, detail="Модель не загружена. Пожалуйста, попробуйте позже.")
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response_future = asyncio.Future()
|
||||||
|
|
||||||
|
# Добавление запроса в очередь
|
||||||
|
with queue_lock:
|
||||||
|
current_queue.put((request.text, response_future))
|
||||||
|
|
||||||
|
try:
|
||||||
|
predicted_class = await asyncio.wait_for(response_future, timeout=10.0)
|
||||||
|
return {"predicted_class": predicted_class}
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise HTTPException(status_code=504, detail="Время ожидания предсказания истекло.")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/status")
|
||||||
|
async def status():
|
||||||
|
if training_in_progress:
|
||||||
|
return {"status": "Обучение продолжается"}
|
||||||
|
elif training_complete:
|
||||||
|
return {"status": "Модель готова"}
|
||||||
|
elif os.path.exists(MODEL_DIR):
|
||||||
|
return {"status": "Модель загружена"}
|
||||||
|
else:
|
||||||
|
return {"status": "Модель отсутствует"}
|
||||||
|
|
||||||
print(examples.train_labels)
|
# Завершение работы
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def shutdown_event():
|
||||||
|
# Остановить воркер предсказаний
|
||||||
|
processing_queue.put(None)
|
||||||
|
prediction_thread.join()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
@ -0,0 +1,81 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
||||||
|
from datasets import Dataset
|
||||||
|
from evaluate import load
|
||||||
|
|
||||||
|
def prepare_datasets(train_texts, train_labels, val_texts, val_labels, tokenizer):
|
||||||
|
"""Создаёт токенизированные датасеты для обучения и валидации."""
|
||||||
|
train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
|
||||||
|
val_dataset = Dataset.from_dict({"text": val_texts, "label": val_labels})
|
||||||
|
|
||||||
|
def tokenize_function(examples):
|
||||||
|
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(tokenize_function, batched=True)
|
||||||
|
val_dataset = val_dataset.map(tokenize_function, batched=True)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.remove_columns(["text"])
|
||||||
|
val_dataset = val_dataset.remove_columns(["text"])
|
||||||
|
|
||||||
|
train_dataset = train_dataset.with_format("torch")
|
||||||
|
val_dataset = val_dataset.with_format("torch")
|
||||||
|
|
||||||
|
return train_dataset, val_dataset
|
||||||
|
|
||||||
|
def train_model(train_dataset, val_dataset, classes, model_name, output_dir="./results", num_epochs=3):
|
||||||
|
"""Обучает модель и сохраняет её в указанной директории."""
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(classes))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
per_device_train_batch_size=8,
|
||||||
|
per_device_eval_batch_size=8,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
weight_decay=0.01,
|
||||||
|
logging_dir="./logs",
|
||||||
|
logging_steps=10,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="accuracy"
|
||||||
|
)
|
||||||
|
|
||||||
|
accuracy_metric = load("accuracy")
|
||||||
|
|
||||||
|
def compute_metrics(eval_pred):
|
||||||
|
logits, labels = eval_pred
|
||||||
|
predictions = np.argmax(logits, axis=-1)
|
||||||
|
return accuracy_metric.compute(predictions=predictions, references=labels)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
compute_metrics=compute_metrics
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
print(f"Модель сохранена в директории {output_dir}")
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
def load_model_and_tokenizer(model_dir, model_name):
|
||||||
|
"""Загружает сохранённую модель и токенизатор."""
|
||||||
|
if not os.path.exists(model_dir):
|
||||||
|
raise ValueError(f"Директория {model_dir} не существует.")
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
def predict(model, tokenizer, text, classes):
|
||||||
|
"""Делает предсказание для заданного текста."""
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
predictions = torch.argmax(outputs.logits, dim=1).item()
|
||||||
|
return classes[predictions]
|
Loading…
Reference in new issue