api updates

pull/160/head
Kye 2 years ago
parent 14d48289fe
commit f381962d91

@ -1,4 +1,3 @@
import os import os
import re import re
from pathlib import Path from pathlib import Path
@ -6,17 +5,16 @@ from typing import Dict, List
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from swarms.agents.workers.agents import AgentManager from swarms import Swarms
from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe
from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal
from env import settings
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / os.getenv["PLAYGROUND_DIR"]) os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR"))
api_key = os.getenv("OPENAI_API_KEY")
toolsets: List[BaseToolSet] = [ toolsets: List[BaseToolSet] = [
Terminal(), Terminal(),
@ -26,7 +24,7 @@ toolsets: List[BaseToolSet] = [
] ]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if os.getenv["USE_GPU"]: if os.getenv("USE_GPU") == "True":
import torch import torch
from swarms.tools.main import ImageCaptioning from swarms.tools.main import ImageCaptioning
@ -48,14 +46,16 @@ if os.getenv["USE_GPU"]:
) )
handlers[FileType.IMAGE] = ImageCaptioning("cuda") handlers[FileType.IMAGE] = ImageCaptioning("cuda")
agent_manager = AgentManager.create(toolsets=toolsets) swarms = Swarms(api_key)
file_handler = FileHandler(handlers=handlers, path=BASE_DIR) file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
uploader = StaticUploader.from_settings( uploader = StaticUploader(
settings, path=BASE_DIR / "static", endpoint="static" static_dir=BASE_DIR / "static",
endpoint="static",
public_url=os.getenv("PUBLIC_URL")
) )
reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"]

@ -1,130 +1,61 @@
import os import os
import re import re
from multiprocessing import Process from pathlib import Path
from tempfile import NamedTemporaryFile from typing import Dict, List
from typing import List, TypedDict from fastapi.templating import Jinja2Templates
import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles from swarms import Swarms
from pydantic import BaseModel from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe
from api.container import agent_manager, file_handler, reload_dirs, templates, uploader from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal
from api.worker import get_task_result, start_worker, task_execute
app = FastAPI() BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR"))
app.mount("/static", StaticFiles(directory=uploader.path), name="static") api_key = os.getenv("OPENAI_API_KEY")
toolsets: List[BaseToolSet] = [
Terminal(),
CodeEditor(),
RequestsGet(),
ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
class ExecuteRequest(BaseModel): if os.getenv("USE_GPU") == "True":
session: str import torch
prompt: str
files: List[str]
from swarms.tools.main import ImageCaptioning
from swarms.tools.main import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
class ExecuteResponse(TypedDict): if torch.cuda.is_available():
answer: str toolsets.extend(
files: List[str] [
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
)
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
swarms = Swarms(api_key)
@app.get("/", response_class=HTMLResponse) file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
@app.get("/dashboard", response_class=HTMLResponse) uploader = StaticUploader(
async def dashboard(request: Request): static_dir=BASE_DIR / "static",
return templates.TemplateResponse("dashboard.html", {"request": request}) endpoint="static",
public_url=os.getenv("PUBLIC_URL")
)
reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"]
@app.post("/upload")
async def create_upload_file(files: List[UploadFile]):
urls = []
for file in files:
extension = "." + file.filename.split(".")[-1]
with NamedTemporaryFile(suffix=extension) as tmp_file:
tmp_file.write(file.file.read())
tmp_file.flush()
urls.append(uploader.upload(tmp_file.name))
return {"urls": urls}
@app.post("/api/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
query = request.prompt
files = request.files
session = request.session
executor = agent_manager.create_executor(session)
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
try:
res = executor({"input": promptedQuery})
except Exception as e:
return {"answer": str(e), "files": []}
files = re.findall(r"\[file://\S*\]", res["output"])
files = [file[1:-1].split("file://")[1] for file in files]
return {
"answer": res["output"],
"files": [uploader.upload(file) for file in files],
}
@app.post("/api/execute/async")
async def execute_async(request: ExecuteRequest):
query = request.prompt
files = request.files
session = request.session
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
execution = task_execute.delay(session, promptedQuery)
return {"id": execution.id}
@app.get("/api/execute/async/{execution_id}")
async def execute_async(execution_id: str):
execution = get_task_result(execution_id)
result = {}
if execution.status == "SUCCESS" and execution.result:
output = execution.result.get("output", "")
files = re.findall(r"\[file://\S*\]", output)
files = [file[1:-1].split("file://")[1] for file in files]
result = {
"answer": output,
"files": [uploader.upload(file) for file in files],
}
return {
"status": execution.status,
"info": execution.info,
"result": result,
}
def serve():
p = Process(target=start_worker, args=[])
p.start()
uvicorn.run("api.main:app", host="0.0.0.0", port=os.getenv["EVAL_PORT"])
def dev():
p = Process(target=start_worker, args=[])
p.start()
uvicorn.run(
"api.main:app",
host="0.0.0.0",
port=os.getenv["EVAL_PORT"],
reload=True,
reload_dirs=reload_dirs,
)

@ -2,10 +2,9 @@ import os
from celery import Celery from celery import Celery
from celery.result import AsyncResult from celery.result import AsyncResult
from api.container import agent_manager from api.container import swarms
# from env import settings
celery_broker = os.environ["CELERY_BROKER_URL"] celery_broker = os.getenv("CELERY_BROKER_URL", "")
celery_app = Celery(__name__) celery_app = Celery(__name__)
@ -22,7 +21,7 @@ celery_app.conf.update(
@celery_app.task(name="task_execute", bind=True) @celery_app.task(name="task_execute", bind=True)
def task_execute(self, session: str, prompt: str): def task_execute(self, session: str, prompt: str):
executor = agent_manager.create_executor(session, self) executor = swarms.create_executor(session, self)
response = executor({"input": prompt}) response = executor({"input": prompt})
result = {"output": response["output"]} result = {"output": response["output"]}
@ -43,4 +42,4 @@ def start_worker():
"worker", "worker",
"--loglevel=INFO", "--loglevel=INFO",
] ]
) )

Loading…
Cancel
Save