diff --git a/api/main.py b/api/main.py index 7c61956a..0f54a674 100644 --- a/api/main.py +++ b/api/main.py @@ -1,61 +1,130 @@ import os + import re -from pathlib import Path -from typing import Dict, List +from multiprocessing import Process +from tempfile import NamedTemporaryFile +from typing import List, TypedDict + +import uvicorn +from fastapi import FastAPI, Request, UploadFile +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel + +from api.container import agent_manager, file_handler, reload_dirs, templates, uploader +from api.worker import get_task_result, start_worker, task_execute +# from env import settings + +app = FastAPI() + +app.mount("/static", StaticFiles(directory=uploader.path), name="static") + + +class ExecuteRequest(BaseModel): + session: str + prompt: str + files: List[str] + + +class ExecuteResponse(TypedDict): + answer: str + files: List[str] + + +@app.get("/", response_class=HTMLResponse) +async def index(request: Request): + return templates.TemplateResponse("index.html", {"request": request}) + + +@app.get("/dashboard", response_class=HTMLResponse) +async def dashboard(request: Request): + return templates.TemplateResponse("dashboard.html", {"request": request}) + + +@app.post("/upload") +async def create_upload_file(files: List[UploadFile]): + urls = [] + for file in files: + extension = "." + file.filename.split(".")[-1] + with NamedTemporaryFile(suffix=extension) as tmp_file: + tmp_file.write(file.file.read()) + tmp_file.flush() + urls.append(uploader.upload(tmp_file.name)) + return {"urls": urls} + + +@app.post("/api/execute") +async def execute(request: ExecuteRequest) -> ExecuteResponse: + query = request.prompt + files = request.files + session = request.session + + executor = agent_manager.create_executor(session) + + promptedQuery = "\n".join([file_handler.handle(file) for file in files]) + promptedQuery += query -from fastapi.templating import Jinja2Templates + try: + res = executor({"input": promptedQuery}) + except Exception as e: + return {"answer": str(e), "files": []} -from swarms import Swarms -from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe + files = re.findall(r"\[file://\S*\]", res["output"]) + files = [file[1:-1].split("file://")[1] for file in files] -from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal + return { + "answer": res["output"], + "files": [uploader.upload(file) for file in files], + } -BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR")) +@app.post("/api/execute/async") +async def execute_async(request: ExecuteRequest): + query = request.prompt + files = request.files + session = request.session -api_key = os.getenv("OPENAI_API_KEY") + promptedQuery = "\n".join([file_handler.handle(file) for file in files]) + promptedQuery += query -toolsets: List[BaseToolSet] = [ - Terminal(), - CodeEditor(), - RequestsGet(), - ExitConversation(), -] -handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} + execution = task_execute.delay(session, promptedQuery) + return {"id": execution.id} -if os.getenv("USE_GPU") == "True": - import torch - from swarms.tools.main import ImageCaptioning - from swarms.tools.main import ( - ImageEditing, - InstructPix2Pix, - Text2Image, - VisualQuestionAnswering, - ) +@app.get("/api/execute/async/{execution_id}") +async def execute_async(execution_id: str): + execution = get_task_result(execution_id) - if torch.cuda.is_available(): - toolsets.extend( - [ - Text2Image("cuda"), - ImageEditing("cuda"), - InstructPix2Pix("cuda"), - VisualQuestionAnswering("cuda"), - ] - ) - handlers[FileType.IMAGE] = ImageCaptioning("cuda") + result = {} + if execution.status == "SUCCESS" and execution.result: + output = execution.result.get("output", "") + files = re.findall(r"\[file://\S*\]", output) + files = [file[1:-1].split("file://")[1] for file in files] + result = { + "answer": output, + "files": [uploader.upload(file) for file in files], + } -swarms = Swarms(api_key) + return { + "status": execution.status, + "info": execution.info, + "result": result, + } -file_handler = FileHandler(handlers=handlers, path=BASE_DIR) -templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") +def serve(): + p = Process(target=start_worker, args=[]) + p.start() + uvicorn.run("api.main:app", host="0.0.0.0", port=os.environ["EVAL_PORT"]) -uploader = StaticUploader( - static_dir=BASE_DIR / "static", - endpoint="static", - public_url=os.getenv("PUBLIC_URL") -) -reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] +def dev(): + p = Process(target=start_worker, args=[]) + p.start() + uvicorn.run( + "api.main:app", + host="0.0.0.0", + port=os.environ["EVAL_PORT"], + reload=True, + reload_dirs=reload_dirs, + ) \ No newline at end of file