You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.0 KiB
44 lines
1.0 KiB
2 years ago
|
import os
|
||
2 years ago
|
|
||
2 years ago
|
from celery import Celery
|
||
|
from celery.result import AsyncResult
|
||
|
|
||
2 years ago
|
from api.olds.container import agent_manager
|
||
2 years ago
|
|
||
|
|
||
|
# Celery application shared by the task definitions and helpers in this module.
celery_app = Celery(__name__)

# CELERY_BROKER_URL is required: a missing variable raises KeyError at import time.
_broker_url = os.environ["CELERY_BROKER_URL"]
celery_app.conf.broker_url = _broker_url
# The result backend defaults to the broker URL (the previous, hard-wired
# behavior) but may now be pointed elsewhere via CELERY_RESULT_BACKEND.
celery_app.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", _broker_url)
celery_app.conf.update(
    # Record a STARTED state so callers can tell queued work from running work.
    task_track_started=True,
    task_serializer="json",
    accept_content=["json"],  # Ignore other content
    result_serializer="json",
    enable_utc=True,
)
|
||
|
|
||
|
|
||
|
@celery_app.task(name="task_execute", bind=True)
def task_execute(self, session: str, prompt: str):
    """Run the agent executor for ``session`` on ``prompt`` and return its output.

    The task is bound (``bind=True``), so ``self`` is the task instance. It is
    handed to ``agent_manager.create_executor`` — presumably so the executor
    can publish progress via ``self.update_state`` (TODO confirm).

    Args:
        session: Session identifier forwarded to ``agent_manager.create_executor``.
        prompt: Text passed to the executor as ``{"input": prompt}``.

    Returns:
        A dict whose ``"output"`` key is the executor's ``response["output"]``,
        merged with any dict-valued metadata already stored in the result
        backend for this task id.
    """
    executor = agent_manager.create_executor(session, self)
    response = executor({"input": prompt})
    result = {"output": response["output"]}

    # Look up what the backend currently holds for this task id. Using the
    # bound task's AsyncResult ties the lookup to this task's app/backend
    # instead of whichever Celery app happens to be the process default.
    # NOTE(review): with task_track_started=True the worker's STARTED meta may
    # still be stored here if nothing overwrote it, and would be merged in —
    # confirm that is intended.
    info = self.AsyncResult(self.request.id).info
    # Only merge mapping metadata; a stored exception or other non-dict value
    # would make dict.update raise.
    if isinstance(info, dict) and info:
        result.update(info)
        # Stored progress metadata must not clobber the final computed output.
        result["output"] = response["output"]

    return result
|
||
|
|
||
|
|
||
|
def get_task_result(task_id):
    """Return a result handle for ``task_id``.

    Uses ``celery_app.AsyncResult`` instead of the bare ``AsyncResult``
    constructor so the lookup always goes through this module's configured
    result backend, not through whichever Celery app is the process default.

    Args:
        task_id: Id of a previously submitted task.

    Returns:
        A ``celery.result.AsyncResult`` bound to ``celery_app``.
    """
    return celery_app.AsyncResult(task_id)
|
||
|
|
||
|
|
||
|
def start_worker():
    """Launch a Celery worker for ``celery_app`` via ``worker_main`` at INFO log level."""
    worker_argv = ["worker", "--loglevel=INFO"]
    celery_app.worker_main(worker_argv)
|