pull/11/head
Kye 2 years ago
parent 2ab6415673
commit 5acdf8f04d

@ -1,61 +1,130 @@
import os import os
import re import re
from pathlib import Path from multiprocessing import Process
from typing import Dict, List from tempfile import NamedTemporaryFile
from typing import List, TypedDict
import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from api.container import agent_manager, file_handler, reload_dirs, templates, uploader
from api.worker import get_task_result, start_worker, task_execute
# from env import settings
app = FastAPI()
app.mount("/static", StaticFiles(directory=uploader.path), name="static")
class ExecuteRequest(BaseModel):
session: str
prompt: str
files: List[str]
class ExecuteResponse(TypedDict):
answer: str
files: List[str]
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
return templates.TemplateResponse("dashboard.html", {"request": request})
@app.post("/upload")
async def create_upload_file(files: List[UploadFile]):
urls = []
for file in files:
extension = "." + file.filename.split(".")[-1]
with NamedTemporaryFile(suffix=extension) as tmp_file:
tmp_file.write(file.file.read())
tmp_file.flush()
urls.append(uploader.upload(tmp_file.name))
return {"urls": urls}
@app.post("/api/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
query = request.prompt
files = request.files
session = request.session
executor = agent_manager.create_executor(session)
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
from fastapi.templating import Jinja2Templates try:
res = executor({"input": promptedQuery})
except Exception as e:
return {"answer": str(e), "files": []}
from swarms import Swarms files = re.findall(r"\[file://\S*\]", res["output"])
from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe files = [file[1:-1].split("file://")[1] for file in files]
from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal return {
"answer": res["output"],
"files": [uploader.upload(file) for file in files],
}
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @app.post("/api/execute/async")
os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR")) async def execute_async(request: ExecuteRequest):
query = request.prompt
files = request.files
session = request.session
api_key = os.getenv("OPENAI_API_KEY") promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
toolsets: List[BaseToolSet] = [ execution = task_execute.delay(session, promptedQuery)
Terminal(), return {"id": execution.id}
CodeEditor(),
RequestsGet(),
ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if os.getenv("USE_GPU") == "True":
import torch
from swarms.tools.main import ImageCaptioning @app.get("/api/execute/async/{execution_id}")
from swarms.tools.main import ( async def execute_async(execution_id: str):
ImageEditing, execution = get_task_result(execution_id)
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
if torch.cuda.is_available(): result = {}
toolsets.extend( if execution.status == "SUCCESS" and execution.result:
[ output = execution.result.get("output", "")
Text2Image("cuda"), files = re.findall(r"\[file://\S*\]", output)
ImageEditing("cuda"), files = [file[1:-1].split("file://")[1] for file in files]
InstructPix2Pix("cuda"), result = {
VisualQuestionAnswering("cuda"), "answer": output,
] "files": [uploader.upload(file) for file in files],
) }
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
swarms = Swarms(api_key) return {
"status": execution.status,
"info": execution.info,
"result": result,
}
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") def serve():
p = Process(target=start_worker, args=[])
p.start()
uvicorn.run("api.main:app", host="0.0.0.0", port=os.environ["EVAL_PORT"])
uploader = StaticUploader(
static_dir=BASE_DIR / "static",
endpoint="static",
public_url=os.getenv("PUBLIC_URL")
)
reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] def dev():
p = Process(target=start_worker, args=[])
p.start()
uvicorn.run(
"api.main:app",
host="0.0.0.0",
port=os.environ["EVAL_PORT"],
reload=True,
reload_dirs=reload_dirs,
)
Loading…
Cancel
Save