#
# LLM Worker #1
# REQUEST --> RESPONSE
#
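# Consumes JSON requests from the `request_llm_topic` Kafka topic, runs them
# through the model on CPU, and publishes JSON responses to `response_llm_topic`.
# Expected request shape (inferred from the handlers below):
#   {"text": "<prompt>", "track_uuid": "<correlation id>"}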

import asyncio
import json
from datetime import datetime, timezone

import torch
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from transformers import AutoTokenizer, AutoModelForCausalLM

cache_dir = "./models"
model_id = "NousResearch/Meta-Llama-3.1-8B"
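
# Note: 8B parameters in float32 take roughly 8e9 * 4 bytes ≈ 32 GB of RAM,
# so this worker assumes a host with enough memory to serve the model on CPU.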

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_dir,
    torch_dtype=torch.float32,  # use float32 for CPU inference
    device_map=None,            # explicitly keep the model off the GPU
)

model.to("cpu")
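
# Caveat: model.generate() is synchronous and CPU-bound, so each call blocks
# the event loop (including aiokafka's background heartbeat tasks) for the
# whole generation. If that becomes a problem, one option is to off-load the
# call, e.g. via `await asyncio.to_thread(...)`.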
async def process_message(message):
    inputs = tokenizer(message['text'], return_tensors="pt").to("cpu")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,  # maximum number of tokens to generate
            do_sample=True,     # enable sampling
            top_p=0.95,         # nucleus sampling parameter
            top_k=50,           # top-k sampling parameter
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    processed_message = {
        "answer": generated_text,
        "track_uuid": message["track_uuid"],
        "processed_timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return processed_message


async def start_consumer_producer():
    consumer = AIOKafkaConsumer(
        'request_llm_topic',
        bootstrap_servers='kafka:9092',
        group_id="processing_group"
    )

    producer = AIOKafkaProducer(bootstrap_servers='kafka:9092')

    await consumer.start()
    await producer.start()

    try:
        async for msg in consumer:
            message = json.loads(msg.value.decode('utf-8'))

            # Process the message
            processed_message = await process_message(message)

            # Serialize the processed message
            processed_message_json = json.dumps(processed_message).encode('utf-8')

            # Send the processed message to the response_llm_topic topic
            await producer.send_and_wait("response_llm_topic", processed_message_json)
            print(f"Processed and sent message with UUID: {processed_message['track_uuid']}")
    finally:
        await consumer.stop()
        await producer.stop()
|


async def main():
    await start_consumer_producer()


if __name__ == "__main__":
    asyncio.run(main())
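

# A minimal way to exercise this worker locally (a sketch, not part of the
# worker itself): publish one request with the shape the handler above expects
# to `request_llm_topic`, then watch `response_llm_topic` for the answer.
#
#   import uuid
#   from aiokafka import AIOKafkaProducer
#
#   async def send_test_request():
#       producer = AIOKafkaProducer(bootstrap_servers='kafka:9092')
#       await producer.start()
#       try:
#           request = {"text": "Hello!", "track_uuid": str(uuid.uuid4())}
#           await producer.send_and_wait(
#               "request_llm_topic", json.dumps(request).encode('utf-8'))
#       finally:
#           await producer.stop()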