parent ace9031261
commit 0200ce7461
.gitignore
@@ -0,0 +1,3 @@
.env
__pycache__
.venv
Dockerfile
@@ -0,0 +1,30 @@
FROM nvidia/cuda:11.7.0-runtime-ubuntu20.04
WORKDIR /app/

ENV DEBIAN_FRONTEND=noninteractive

RUN \
    apt-get update && \
    apt-get install -y software-properties-common && \
    add-apt-repository ppa:deadsnakes/ppa && \
    apt-get install -y python3.10 python3-pip curl && \
    curl -sSL https://install.python-poetry.org | python3 - && \
    apt-get install -y nodejs npm

ENV PATH "/root/.local/bin:$PATH"

COPY pyproject.toml .
COPY poetry.lock .

COPY api/__init__.py api/__init__.py
RUN poetry config virtualenvs.in-project true
RUN poetry config virtualenvs.path .venv
RUN poetry config installer.max-workers 10
RUN poetry env use 3.10
RUN poetry install --with tools,gpu

COPY . .

ENV PORT 8001

ENTRYPOINT ["poetry", "run", "serve"]
api/container.py
@@ -0,0 +1,65 @@

import os
import re
from pathlib import Path
from typing import Dict, List

from fastapi.templating import Jinja2Templates

from swarms.agents.workers.agents import AgentManager
from swarms.tools.main import BaseHandler, FileHandler, FileType
from swarms.tools.main import CsvToDataframe

from swarms.tools.main import BaseToolSet
from swarms.tools.main import ExitConversation, RequestsGet
from swarms.tools.main import CodeEditor

from swarms.tools.main import Terminal
from swarms.tools.main import StaticUploader

from env import settings  # `settings` is used below; assumed to come from the project's env module (see the commented-out import in api/worker.py)


BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / os.environ["PLAYGROUND_DIR"])


toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}

if os.getenv("USE_GPU"):
    import torch

    from swarms.tools.main import ImageCaptioning
    from swarms.tools.main import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")

agent_manager = AgentManager.create(toolsets=toolsets)

file_handler = FileHandler(handlers=handlers, path=BASE_DIR)

templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")

uploader = StaticUploader.from_settings(
    settings, path=BASE_DIR / "static", endpoint="static"
)

reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"]
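Note: api/container.py assumes the PLAYGROUND_DIR and USE_GPU environment variables are set (presumably via the .env file ignored above). A minimal, hedged sketch of reading them defensively; the default values here are illustrative only and not part of this commit:

import os
from pathlib import Path

# Hypothetical defaults for illustration; the commit itself requires
# PLAYGROUND_DIR to be present and treats USE_GPU as an optional flag.
PLAYGROUND_DIR = os.getenv("PLAYGROUND_DIR", "playground")
USE_GPU = os.getenv("USE_GPU", "").lower() in ("1", "true", "yes")

BASE_DIR = Path(__file__).resolve().parent.parent
(BASE_DIR / PLAYGROUND_DIR).mkdir(parents=True, exist_ok=True)
os.chdir(BASE_DIR / PLAYGROUND_DIR)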
api/main.py
@@ -0,0 +1,130 @@
import os
import re
from multiprocessing import Process
from tempfile import NamedTemporaryFile

from typing import List, TypedDict
import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse

from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from api.container import agent_manager, file_handler, reload_dirs, templates, uploader
from api.worker import get_task_result, start_worker, task_execute


app = FastAPI()

app.mount("/static", StaticFiles(directory=uploader.path), name="static")


class ExecuteRequest(BaseModel):
    session: str
    prompt: str
    files: List[str]


class ExecuteResponse(TypedDict):
    answer: str
    files: List[str]


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
    return templates.TemplateResponse("dashboard.html", {"request": request})


@app.post("/upload")
async def create_upload_file(files: List[UploadFile]):
    urls = []
    for file in files:
        extension = "." + file.filename.split(".")[-1]
        with NamedTemporaryFile(suffix=extension) as tmp_file:
            tmp_file.write(file.file.read())
            tmp_file.flush()
            urls.append(uploader.upload(tmp_file.name))
    return {"urls": urls}


@app.post("/api/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
    query = request.prompt
    files = request.files
    session = request.session

    executor = agent_manager.create_executor(session)

    promptedQuery = "\n".join([file_handler.handle(file) for file in files])
    promptedQuery += query

    try:
        res = executor({"input": promptedQuery})
    except Exception as e:
        return {"answer": str(e), "files": []}

    files = re.findall(r"\[file://\S*\]", res["output"])
    files = [file[1:-1].split("file://")[1] for file in files]

    return {
        "answer": res["output"],
        "files": [uploader.upload(file) for file in files],
    }


@app.post("/api/execute/async")
async def execute_async(request: ExecuteRequest):
    query = request.prompt
    files = request.files
    session = request.session

    promptedQuery = "\n".join([file_handler.handle(file) for file in files])
    promptedQuery += query

    execution = task_execute.delay(session, promptedQuery)
    return {"id": execution.id}


@app.get("/api/execute/async/{execution_id}")
async def execute_async_result(execution_id: str):
    execution = get_task_result(execution_id)

    result = {}
    if execution.status == "SUCCESS" and execution.result:
        output = execution.result.get("output", "")
        files = re.findall(r"\[file://\S*\]", output)
        files = [file[1:-1].split("file://")[1] for file in files]
        result = {
            "answer": output,
            "files": [uploader.upload(file) for file in files],
        }

    return {
        "status": execution.status,
        "info": execution.info,
        "result": result,
    }


def serve():
    p = Process(target=start_worker, args=[])
    p.start()
    uvicorn.run("api.main:app", host="0.0.0.0", port=int(os.environ["EVAL_PORT"]))


def dev():
    p = Process(target=start_worker, args=[])
    p.start()
    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=int(os.environ["EVAL_PORT"]),
        reload=True,
        reload_dirs=reload_dirs,
    )
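Note: a rough client-side sketch of how the async endpoints above are meant to be used. The base URL is an assumption (the Dockerfile sets PORT=8001 while serve() reads EVAL_PORT), and the session name and prompt are placeholders:

import time
import requests

BASE_URL = "http://localhost:8001"  # assumed host/port

# Queue a prompt for background execution, then poll for the Celery result.
resp = requests.post(
    f"{BASE_URL}/api/execute/async",
    json={"session": "demo-session", "prompt": "Summarize the uploaded file", "files": []},
)
execution_id = resp.json()["id"]

while True:
    status = requests.get(f"{BASE_URL}/api/execute/async/{execution_id}").json()
    if status["status"] in ("SUCCESS", "FAILURE"):
        print(status["result"] or status["info"])
        break
    time.sleep(1)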
api/worker.py
@@ -0,0 +1,46 @@
import os
from celery import Celery
from celery.result import AsyncResult

from api.container import agent_manager
# from env import settings

celery_broker = os.environ["CELERY_BROKER_URL"]


celery_app = Celery(__name__)
celery_app.conf.broker_url = celery_broker
celery_app.conf.result_backend = celery_broker
celery_app.conf.update(
    task_track_started=True,
    task_serializer="json",
    accept_content=["json"],  # Ignore other content
    result_serializer="json",
    enable_utc=True,
)


@celery_app.task(name="task_execute", bind=True)
def task_execute(self, session: str, prompt: str):
    executor = agent_manager.create_executor(session, self)
    response = executor({"input": prompt})
    result = {"output": response["output"]}

    previous = AsyncResult(self.request.id)
    if previous and previous.info:
        result.update(previous.info)

    return result


def get_task_result(task_id):
    return AsyncResult(task_id)


def start_worker():
    celery_app.worker_main(
        [
            "worker",
            "--loglevel=INFO",
        ]
    )
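Note: the worker is normally started in-process by serve()/dev() in api/main.py, but the same task round-trip can be exercised directly once a broker is running and CELERY_BROKER_URL points at it. A small hedged sketch (session name and prompt are placeholders):

from api.worker import task_execute, get_task_result

# Enqueue the task exactly as /api/execute/async does, then read it back.
execution = task_execute.delay("demo-session", "What files are in the playground?")
result = get_task_result(execution.id)
print(result.status)        # PENDING / STARTED / SUCCESS / FAILURE
if result.ready():
    print(result.result)    # {"output": ...} as returned by task_execute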