You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/api/main.py

62 lines
1.6 KiB

2 years ago
import os
import re
2 years ago
from pathlib import Path
from typing import Dict, List
2 years ago
2 years ago
from fastapi.templating import Jinja2Templates
2 years ago
2 years ago
from swarms import Swarms
from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe
2 years ago
2 years ago
from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal
2 years ago
2 years ago
# Directory two levels above this file, i.e. the parent of the directory that
# contains this module. ``abspath`` (not ``resolve``) is kept on purpose so
# symlinks are not followed, matching the original dirname/dirname behaviour.
BASE_DIR = Path(os.path.abspath(__file__)).parents[1]

# Change the process-wide working directory to the playground directory,
# resolved relative to BASE_DIR. This is an import-time side effect. Per
# pathlib join semantics, an absolute PLAYGROUND_DIR replaces BASE_DIR
# entirely, and an empty value leaves the target at BASE_DIR itself.
_playground_dir = os.getenv("PLAYGROUND_DIR")
if _playground_dir is None:
    # Without this guard ``BASE_DIR / None`` raises an opaque TypeError that
    # does not mention which setting is missing.
    raise RuntimeError(
        "The PLAYGROUND_DIR environment variable is not set; it must name the "
        "playground directory (relative to BASE_DIR or absolute)."
    )
os.chdir(BASE_DIR / _playground_dir)
2 years ago
2 years ago
# OpenAI API key taken from the environment; None when OPENAI_API_KEY is unset.
api_key = os.environ.get("OPENAI_API_KEY")
2 years ago
2 years ago
# Tool sets registered unconditionally; GPU-only tools may be appended later.
toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]

# File handlers keyed by the FileType they process. Only the DATAFRAME handler
# is registered here; other entries may be added further down this module.
handlers: Dict[FileType, BaseHandler] = {
    FileType.DATAFRAME: CsvToDataframe(),
}
2 years ago
2 years ago
# Optional GPU-backed tools. Enabled only when USE_GPU is exactly the string
# "True"; any other value (including "true" or "1") leaves them disabled.
if os.getenv("USE_GPU") == "True":
    import torch

    from swarms.tools.main import (
        ImageCaptioning,
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        # Register the image tools on the CUDA device.
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        # Images uploaded through the file handler are captioned on CUDA.
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
    else:
        import warnings

        # GPU tools were explicitly requested but cannot be used; surface this
        # instead of silently starting without them.
        warnings.warn(
            "USE_GPU is 'True' but torch.cuda.is_available() returned False; "
            "GPU tools (Text2Image, ImageEditing, InstructPix2Pix, "
            "VisualQuestionAnswering, ImageCaptioning) were not registered.",
            RuntimeWarning,
            stacklevel=1,
        )
2 years ago
2 years ago
# Main Swarms object, constructed with the key read from OPENAI_API_KEY
# (which may be None if that variable is unset).
swarms = Swarms(api_key)

# File handler configured with the handler table and BASE_DIR as its path.
file_handler = FileHandler(path=BASE_DIR, handlers=handlers)

# Jinja2 templates loaded from BASE_DIR/api/templates.
templates = Jinja2Templates(directory=BASE_DIR.joinpath("api", "templates"))

# Static-file uploader rooted at BASE_DIR/static and served under the "static"
# endpoint; public_url is None when PUBLIC_URL is unset.
uploader = StaticUploader(
    static_dir=BASE_DIR.joinpath("static"),
    endpoint="static",
    public_url=os.environ.get("PUBLIC_URL"),
)

# Directories to watch for auto-reload.
# NOTE(review): BASE_DIR is two levels above this file, so BASE_DIR / "api" is
# this module's own directory, while BASE_DIR / "swarms" may not exist —
# confirm these are the intended reload roots.
reload_dirs = [BASE_DIR / name for name in ("swarms", "api")]