You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/api/olds/main.py

130 lines
3.5 KiB

2 years ago
import os
2 years ago
2 years ago
import re
2 years ago
from multiprocessing import Process
from tempfile import NamedTemporaryFile
from typing import List, TypedDict
import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
2 years ago
from api.olds.container import agent_manager, file_handler, reload_dirs, templates, uploader
from api.olds.worker import get_task_result, start_worker, task_execute
2 years ago
app = FastAPI()

# Serve the uploader's storage directory under the /static URL prefix.
_static_files = StaticFiles(directory=uploader.path)
app.mount("/static", _static_files, name="static")
class ExecuteRequest(BaseModel):
    """Request body accepted by both the synchronous and queued execute endpoints."""

    # Session identifier handed to the agent executor / background task.
    session: str
    # User prompt; appended after any text produced from ``files``.
    prompt: str
    # File references, each converted to prompt text by ``file_handler.handle``.
    files: List[str]
# Response shape of the synchronous execute endpoint: the agent's answer text
# plus URLs of any files referenced in that answer.
ExecuteResponse = TypedDict(
    "ExecuteResponse",
    {
        "answer": str,
        "files": List[str],
    },
)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template for the site root."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
    """Render the ``dashboard.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("dashboard.html", context)
@app.post("/upload")
async def create_upload_file(files: List[UploadFile]):
    """Hand each uploaded file to ``uploader`` and return the resulting URLs.

    Every upload is written to a named temporary file that keeps the original
    file's extension. The temp file's path is passed to ``uploader.upload``,
    and the file is deleted when its ``with`` block exits.

    Args:
        files: files sent in the multipart request.

    Returns:
        ``{"urls": [...]}``, one URL per uploaded file, in request order.
    """
    urls = []
    for upload in files:
        # splitext yields "" for names without an extension. The old
        # split(".")[-1] turned e.g. "README" into the suffix ".README" and
        # could carry "/" into the suffix. filename may also be None.
        extension = os.path.splitext(upload.filename or "")[1]
        # Use the async UploadFile API rather than a synchronous read of the
        # underlying file object inside this coroutine.
        content = await upload.read()
        with NamedTemporaryFile(suffix=extension) as tmp_file:
            tmp_file.write(content)
            # Make sure the bytes are on disk before the uploader opens the path.
            tmp_file.flush()
            urls.append(uploader.upload(tmp_file.name))
    return {"urls": urls}
@app.post("/api/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
    """Run the session's agent executor inline and return its answer.

    Each entry in ``request.files`` is converted to prompt text by
    ``file_handler.handle`` and placed before the user's prompt. Every
    ``[file://<path>]`` marker in the executor's output is uploaded, and the
    resulting URL is returned.

    Args:
        request: session id, prompt text and input file references.

    Returns:
        ``{"answer": ..., "files": [...]}``. If the executor raises, the
        exception message is returned as ``answer`` with an empty ``files``.
    """
    executor = agent_manager.create_executor(request.session)

    # NOTE(review): the file text is concatenated to the prompt with no
    # separator; confirm file_handler.handle output ends with a newline.
    prompted_query = "\n".join([file_handler.handle(file) for file in request.files])
    prompted_query += request.prompt

    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        # Report executor failures to the client as the answer text.
        return {"answer": str(e), "files": []}

    output = res["output"]
    # Capture only the path inside each [file://...] marker. Excluding "]"
    # keeps adjacent markers such as "[file://a][file://b]" from merging into
    # one bogus path, which the greedy \S* pattern did.
    paths = re.findall(r"\[file://([^\]\s]+)\]", output)
    return {
        "answer": output,
        "files": [uploader.upload(path) for path in paths],
    }
2 years ago
2 years ago
@app.post("/api/execute/async")
async def execute_async(request: ExecuteRequest):
    """Queue an execution as a background task and return its id right away.

    The prompt is built the same way as for the synchronous endpoint: the
    text from each referenced file, newline-joined, followed by the prompt.
    """
    file_prompts = [file_handler.handle(path) for path in request.files]
    full_prompt = "\n".join(file_prompts) + request.prompt

    task = task_execute.delay(request.session, full_prompt)
    return {"id": task.id}
2 years ago
2 years ago
@app.get("/api/execute/async/{execution_id}")
async def execute_async(execution_id: str):
    """Report the status and, on success, the result of a queued execution.

    Args:
        execution_id: task id, presumably the ``id`` returned by the POST
            ``/api/execute/async`` endpoint.

    Returns:
        ``{"status": ..., "info": ..., "result": ...}``. ``result`` stays
        ``{}`` unless the status is ``"SUCCESS"`` and the task result is
        truthy. In that case it holds the output text and the uploaded URLs
        for every ``[file://<path>]`` marker found in it.
    """
    # NOTE(review): this def reuses the name of the POST handler, so the
    # module attribute ``execute_async`` ends up bound to this function. The
    # route decorators presumably registered both handlers already; consider
    # a distinct name (e.g. get_execution_result) once callers are checked.
    execution = get_task_result(execution_id)

    result = {}
    if execution.status == "SUCCESS" and execution.result:
        # Assumes the task result is a dict with an "output" key — confirm
        # against the worker.
        output = execution.result.get("output", "")
        # Capture only the path inside each marker. Excluding "]" keeps
        # adjacent markers from merging into one bogus path.
        paths = re.findall(r"\[file://([^\]\s]+)\]", output)
        result = {
            "answer": output,
            "files": [uploader.upload(path) for path in paths],
        }

    return {
        "status": execution.status,
        "info": execution.info,
        "result": result,
    }
2 years ago
2 years ago
def serve():
    """Start the background task worker, then run the API server.

    ``start_worker`` runs in a separate ``multiprocessing.Process``. uvicorn
    then serves ``api.main:app`` on all interfaces, on the port given by the
    ``EVAL_PORT`` environment variable.

    Raises:
        KeyError: if ``EVAL_PORT`` is not set.
        ValueError: if ``EVAL_PORT`` is not an integer.
    """
    # Environment values are strings, but uvicorn binds its socket with this
    # port, so it must be an int. Parse it before spawning the worker so a
    # bad value fails fast without leaving a worker process running.
    port = int(os.environ["EVAL_PORT"])
    worker = Process(target=start_worker, args=[])
    worker.start()
    # NOTE(review): this module lives in the api.olds package, but the import
    # string targets api.main — confirm which app is meant to be served.
    uvicorn.run("api.main:app", host="0.0.0.0", port=port)
2 years ago
2 years ago
def dev():
    """Start the background task worker, then run uvicorn with auto-reload.

    ``start_worker`` runs in a separate ``multiprocessing.Process``. uvicorn
    serves ``api.main:app`` on all interfaces, on the port from
    ``EVAL_PORT``, with ``reload=True`` and watching ``reload_dirs``.

    Raises:
        KeyError: if ``EVAL_PORT`` is not set.
        ValueError: if ``EVAL_PORT`` is not an integer.
    """
    # uvicorn needs an int port, and environment values are strings. Parse
    # it before spawning the worker so a bad value fails fast without
    # leaving a worker process running.
    port = int(os.environ["EVAL_PORT"])
    worker = Process(target=start_worker, args=[])
    worker.start()
    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=port,
        reload=True,
        reload_dirs=reload_dirs,
    )