From 6db21fbbe5d73dc65fb353d8bf5ba923bf853a4a Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 7 Jul 2023 16:47:02 -0400 Subject: [PATCH] api updates --- api/container.py | 20 +++--- api/main.py | 157 +++++++++++++---------------------------------- api/worker.py | 9 ++- 3 files changed, 58 insertions(+), 128 deletions(-) diff --git a/api/container.py b/api/container.py index 6c591a01..7c61956a 100644 --- a/api/container.py +++ b/api/container.py @@ -1,4 +1,3 @@ - import os import re from pathlib import Path @@ -6,17 +5,16 @@ from typing import Dict, List from fastapi.templating import Jinja2Templates -from swarms.agents.workers.agents import AgentManager +from swarms import Swarms from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal -from env import settings - BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -os.chdir(BASE_DIR / os.getenv["PLAYGROUND_DIR"]) +os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR")) +api_key = os.getenv("OPENAI_API_KEY") toolsets: List[BaseToolSet] = [ Terminal(), @@ -26,7 +24,7 @@ toolsets: List[BaseToolSet] = [ ] handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} -if os.getenv["USE_GPU"]: +if os.getenv("USE_GPU") == "True": import torch from swarms.tools.main import ImageCaptioning @@ -48,14 +46,16 @@ if os.getenv["USE_GPU"]: ) handlers[FileType.IMAGE] = ImageCaptioning("cuda") -agent_manager = AgentManager.create(toolsets=toolsets) +swarms = Swarms(api_key) file_handler = FileHandler(handlers=handlers, path=BASE_DIR) templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") -uploader = StaticUploader.from_settings( - settings, path=BASE_DIR / "static", endpoint="static" +uploader = StaticUploader( + static_dir=BASE_DIR / "static", + endpoint="static", + public_url=os.getenv("PUBLIC_URL") ) -reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] \ No newline at end of file +reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] diff --git a/api/main.py b/api/main.py index 1d2a44ad..7c61956a 100644 --- a/api/main.py +++ b/api/main.py @@ -1,130 +1,61 @@ import os import re -from multiprocessing import Process -from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import Dict, List -from typing import List, TypedDict -import uvicorn -from fastapi import FastAPI, Request, UploadFile -from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel +from swarms import Swarms +from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe -from api.container import agent_manager, file_handler, reload_dirs, templates, uploader -from api.worker import get_task_result, start_worker, task_execute +from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal -app = FastAPI() +BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.chdir(BASE_DIR / os.getenv("PLAYGROUND_DIR")) -app.mount("/static", StaticFiles(directory=uploader.path), name="static") +api_key = os.getenv("OPENAI_API_KEY") +toolsets: List[BaseToolSet] = [ + Terminal(), + CodeEditor(), + RequestsGet(), + ExitConversation(), +] +handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} -class ExecuteRequest(BaseModel): - session: str - prompt: str - files: List[str] +if os.getenv("USE_GPU") == "True": + import torch + from swarms.tools.main import ImageCaptioning + from swarms.tools.main import ( + ImageEditing, + InstructPix2Pix, + Text2Image, + VisualQuestionAnswering, + ) -class ExecuteResponse(TypedDict): - answer: str - files: List[str] + if torch.cuda.is_available(): + toolsets.extend( + [ + Text2Image("cuda"), + ImageEditing("cuda"), + InstructPix2Pix("cuda"), + VisualQuestionAnswering("cuda"), + ] + ) + handlers[FileType.IMAGE] = ImageCaptioning("cuda") +swarms = Swarms(api_key) -@app.get("/", response_class=HTMLResponse) -async def index(request: Request): - return templates.TemplateResponse("index.html", {"request": request}) +file_handler = FileHandler(handlers=handlers, path=BASE_DIR) +templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") -@app.get("/dashboard", response_class=HTMLResponse) -async def dashboard(request: Request): - return templates.TemplateResponse("dashboard.html", {"request": request}) +uploader = StaticUploader( + static_dir=BASE_DIR / "static", + endpoint="static", + public_url=os.getenv("PUBLIC_URL") +) - -@app.post("/upload") -async def create_upload_file(files: List[UploadFile]): - urls = [] - for file in files: - extension = "." + file.filename.split(".")[-1] - with NamedTemporaryFile(suffix=extension) as tmp_file: - tmp_file.write(file.file.read()) - tmp_file.flush() - urls.append(uploader.upload(tmp_file.name)) - return {"urls": urls} - - -@app.post("/api/execute") -async def execute(request: ExecuteRequest) -> ExecuteResponse: - query = request.prompt - files = request.files - session = request.session - - executor = agent_manager.create_executor(session) - - promptedQuery = "\n".join([file_handler.handle(file) for file in files]) - promptedQuery += query - - try: - res = executor({"input": promptedQuery}) - except Exception as e: - return {"answer": str(e), "files": []} - - files = re.findall(r"\[file://\S*\]", res["output"]) - files = [file[1:-1].split("file://")[1] for file in files] - - return { - "answer": res["output"], - "files": [uploader.upload(file) for file in files], - } - - -@app.post("/api/execute/async") -async def execute_async(request: ExecuteRequest): - query = request.prompt - files = request.files - session = request.session - - promptedQuery = "\n".join([file_handler.handle(file) for file in files]) - promptedQuery += query - - execution = task_execute.delay(session, promptedQuery) - return {"id": execution.id} - - -@app.get("/api/execute/async/{execution_id}") -async def execute_async(execution_id: str): - execution = get_task_result(execution_id) - - result = {} - if execution.status == "SUCCESS" and execution.result: - output = execution.result.get("output", "") - files = re.findall(r"\[file://\S*\]", output) - files = [file[1:-1].split("file://")[1] for file in files] - result = { - "answer": output, - "files": [uploader.upload(file) for file in files], - } - - return { - "status": execution.status, - "info": execution.info, - "result": result, - } - - -def serve(): - p = Process(target=start_worker, args=[]) - p.start() - uvicorn.run("api.main:app", host="0.0.0.0", port=os.getenv["EVAL_PORT"]) - - -def dev(): - p = Process(target=start_worker, args=[]) - p.start() - uvicorn.run( - "api.main:app", - host="0.0.0.0", - port=os.getenv["EVAL_PORT"], - reload=True, - reload_dirs=reload_dirs, - ) \ No newline at end of file +reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"] diff --git a/api/worker.py b/api/worker.py index 798d8ab8..e36bc95d 100644 --- a/api/worker.py +++ b/api/worker.py @@ -2,10 +2,9 @@ import os from celery import Celery from celery.result import AsyncResult -from api.container import agent_manager -# from env import settings +from api.container import swarms -celery_broker = os.environ["CELERY_BROKER_URL"] +celery_broker = os.getenv("CELERY_BROKER_URL", "") celery_app = Celery(__name__) @@ -22,7 +21,7 @@ celery_app.conf.update( @celery_app.task(name="task_execute", bind=True) def task_execute(self, session: str, prompt: str): - executor = agent_manager.create_executor(session, self) + executor = swarms.create_executor(session, self) response = executor({"input": prompt}) result = {"output": response["output"]} @@ -43,4 +42,4 @@ def start_worker(): "worker", "--loglevel=INFO", ] - ) \ No newline at end of file + )