You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/api/container.py

61 lines
1.7 KiB

2 years ago
import os
import re
from pathlib import Path
from typing import Dict, List
from fastapi.templating import Jinja2Templates
2 years ago
from swarms.agents.utils.manager import AgentManager
from swarms.utils.utils import BaseHandler, FileHandler, FileType
from swarms.tools.main import CsvToDataframe, ExitConversation, RequestsGet, CodeEditor, Terminal
2 years ago
2 years ago
from swarms.tools.main import BaseToolSet
2 years ago
2 years ago
from swarms.utils.utils import StaticUploader
2 years ago
# Directory two levels above this file (the parent of this module's folder).
BASE_DIR = Path(os.path.abspath(__file__)).parent.parent
2 years ago
# Switch the process working directory to the playground folder named by the
# PLAYGROUND_DIR environment variable (resolved relative to BASE_DIR).
# A missing variable raises KeyError at import time.
os.chdir(BASE_DIR.joinpath(os.environ["PLAYGROUND_DIR"]))
2 years ago
# Toolsets that are always registered, instantiated in this order.
toolsets: List[BaseToolSet] = [
    toolset_cls()
    for toolset_cls in (Terminal, CodeEditor, RequestsGet, ExitConversation)
]

# File handlers keyed by file type; only DATAFRAME is registered up front.
handlers: Dict[FileType, BaseHandler] = {}
handlers[FileType.DATAFRAME] = CsvToDataframe()
2 years ago
# GPU-backed tools are opt-in through the USE_GPU environment variable.
# Environment values are always strings, so a bare truthiness test would treat
# "False" or "0" as enabled. Explicit negative spellings (and the empty
# string) are therefore recognised as "disabled"; any other value enables the
# GPU path. A missing variable still raises KeyError.
if os.environ["USE_GPU"].strip().lower() not in ("", "0", "false", "no", "off"):
    # Imported here so torch and the GPU tools are only loaded on request.
    import torch

    from swarms.tools.main import (
        ImageCaptioning,
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    # Only register the CUDA tools when torch reports a usable CUDA device.
    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        # IMAGE files are handled by the captioning handler, also on CUDA.
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
2 years ago
# Create the module-level agent manager from the toolsets configured above.
agent_manager = AgentManager.create(toolsets=toolsets)
2 years ago
# File handler built from the handlers mapping, with BASE_DIR as its path.
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)

# Jinja2 templates are loaded from <BASE_DIR>/api/templates.
templates = Jinja2Templates(directory=BASE_DIR.joinpath("api", "templates"))
2 years ago
# Uploader configured with the <BASE_DIR>/static directory and the "static" endpoint.
uploader = StaticUploader.from_settings(path=BASE_DIR.joinpath("static"), endpoint="static")
2 years ago
# Directories under BASE_DIR listed for reloading: "core" and "api", in that order.
reload_dirs = [BASE_DIR / dir_name for dir_name in ("core", "api")]