You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.6 KiB

#
# LLM Worker #1
# REQUEST --> RESPONSE
#
import os
import asyncio
import json
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from datetime import datetime
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Local directory used as the Hugging Face download cache.
cache_dir = "./models"
model_id = "NousResearch/Meta-Llama-3.1-8B"

# Give the tokenizer the same cache_dir as the model so that both sets of
# files are stored under ./models rather than in two different caches.
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_dir,
    torch_dtype=torch.float32,  # use float32 for CPU
    device_map=None,  # explicitly state that the model will not use a GPU
)
model.to("cpu")
def _generate_answer(text):
    """Tokenize ``text``, sample a continuation and return the decoded string.

    This is a blocking call; it is meant to be run in a worker thread.
    """
    inputs = tokenizer(text, return_tensors="pt").to("cpu")
    # Grad mode is thread-local in PyTorch, so no_grad has to be entered in
    # the thread that actually executes generate().
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,  # maximum number of generated tokens
            do_sample=True,  # enable sampling
            top_p=0.95,  # nucleus sampling parameter
            top_k=50,  # top-k sampling parameter
        )
    # NOTE(review): outputs[0] is decoded as a whole, exactly as before; for a
    # causal LM this presumably includes the prompt tokens - confirm intended.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


async def process_message(message):
    """Generate an answer for one request message.

    Args:
        message: Mapping with a ``'text'`` entry (the prompt) and a
            ``'track_uuid'`` entry that is echoed back in the result.

    Returns:
        dict with keys ``"answer"``, ``"track_uuid"`` and
        ``"processed_timestamp"`` (ISO-8601 string from ``datetime.utcnow()``).

    Raises:
        KeyError: if ``'text'`` or ``'track_uuid'`` is missing.
    """
    # Read both required fields up front so a malformed message fails before
    # the expensive generation step instead of after it.
    prompt = message['text']
    track_uuid = message["track_uuid"]

    # Generation is CPU-bound and blocking. Running it in the default executor
    # keeps the event loop free for other tasks while the model works.
    loop = asyncio.get_running_loop()
    generated_text = await loop.run_in_executor(None, _generate_answer, prompt)

    # NOTE(review): utcnow() returns a naive datetime and is deprecated since
    # Python 3.12; it is kept here so the emitted timestamp format is unchanged.
    return {
        "answer": generated_text,
        "track_uuid": track_uuid,
        "processed_timestamp": datetime.utcnow().isoformat(),
    }
def _parse_request(raw):
    """Decode a raw Kafka value into a request dict, or return None if unusable."""
    if raw is None:
        return None
    try:
        payload = json.loads(raw.decode('utf-8'))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueError subclasses.
        return None
    if not isinstance(payload, dict):
        return None
    if "text" not in payload or "track_uuid" not in payload:
        return None
    return payload


async def start_consumer_producer():
    """Consume requests, run them through ``process_message`` and publish replies.

    Reads messages from ``request_llm_topic``, and for each valid request
    sends the JSON-encoded result to ``response_llm_topic``. Messages that
    cannot be decoded or that lack the ``text`` / ``track_uuid`` keys are
    skipped with a printed note instead of terminating the worker.

    The consumer and producer are always stopped on exit, including when the
    producer fails to start or when stopping one of them raises.
    """
    consumer = AIOKafkaConsumer(
        'request_llm_topic',
        bootstrap_servers='kafka:9092',
        group_id="processing_group"
    )
    producer = AIOKafkaProducer(bootstrap_servers='kafka:9092')

    await consumer.start()
    try:
        # Started inside the consumer's try so that a failing producer.start()
        # does not leave the consumer running.
        await producer.start()
        try:
            async for msg in consumer:
                message = _parse_request(msg.value)
                if message is None:
                    print(
                        "Skipping malformed message at "
                        f"{msg.topic}[{msg.partition}] offset {msg.offset}"
                    )
                    continue
                # Process the message
                processed_message = await process_message(message)
                # Serialize the processed message
                processed_message_json = json.dumps(processed_message).encode('utf-8')
                # Send the processed message to the response_llm_topic topic
                await producer.send_and_wait("response_llm_topic", processed_message_json)
                print(f"Processed and sent message with UUID: {processed_message['track_uuid']}")
        finally:
            await producer.stop()
    finally:
        # Outer finally: runs even if producer.stop() itself raised.
        await consumer.stop()
async def main():
    """Entry-point coroutine: run the consume/process/produce loop until it ends."""
    worker = start_consumer_producer()
    await worker
# Start the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    entry_point = main()
    asyncio.run(entry_point)