You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/api/container.py

61 lines
1.7 KiB

2 years ago
import os
import re
from pathlib import Path
from typing import Dict, List
from fastapi.templating import Jinja2Templates
from swarms.agents.utils.AgentManager import AgentManager
from swarms.utils.main import BaseHandler, FileHandler, FileType
from swarms.tools.main import CsvToDataframe, ExitConversation, RequestsGet, CodeEditor, Terminal
2 years ago
from swarms.tools.main import BaseToolSet
2 years ago
from swarms.utils.main import StaticUploader
2 years ago
# Directory two levels above this file; every path below is built from it.
BASE_DIR = Path(os.path.abspath(__file__)).parent.parent

# Make the directory named by PLAYGROUND_DIR (resolved against BASE_DIR) the
# process working directory. The variable is mandatory: a missing
# PLAYGROUND_DIR raises KeyError at import time.
os.chdir(BASE_DIR.joinpath(os.environ["PLAYGROUND_DIR"]))
2 years ago
# Tool sets that are always registered; the GPU section further down may
# extend this list before it is handed to AgentManager.create().
toolsets: List[BaseToolSet] = [Terminal(), CodeEditor(), RequestsGet(), ExitConversation()]

# Handlers keyed by file type; this mapping is later passed to FileHandler.
handlers: Dict[FileType, BaseHandler] = {}
handlers[FileType.DATAFRAME] = CsvToDataframe()
def env_flag_enabled(raw: str) -> bool:
    """Return whether an environment-variable string should count as "on".

    Environment values are always strings, so plain truthiness treats
    ``"False"`` or ``"0"`` as enabled. An empty or whitespace-only value and
    the usual negative spellings (case-insensitive) are treated as disabled;
    any other value is enabled, exactly as it was under a truthiness check.

    Args:
        raw: The raw string read from the environment.

    Returns:
        ``False`` for ``""``, ``"0"``, ``"false"``, ``"no"`` and ``"off"``
        (ignoring case and surrounding whitespace), ``True`` otherwise.
    """
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Optional GPU tooling. Reading with .get() keeps an unset USE_GPU from raising
# KeyError at import time; an unset variable simply means "no GPU tools".
if env_flag_enabled(os.environ.get("USE_GPU", "")):
    # Imported only on this branch, so torch and the image tools are not
    # loaded when GPU support is switched off.
    import torch

    from swarms.tools.main import (
        ImageCaptioning,
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    # Register the CUDA-backed tools only when torch reports CUDA as available.
    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
# Module-level singletons, built once at import time in this order.

# Agent manager created from the tool sets registered above.
agent_manager = AgentManager.create(toolsets=toolsets)

# File handler using the per-file-type handlers, with BASE_DIR as its path.
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)

# Jinja2 templates are loaded from <BASE_DIR>/api/templates.
templates = Jinja2Templates(directory=BASE_DIR.joinpath("api", "templates"))

# Static uploader configured with <BASE_DIR>/static and the "static" endpoint.
uploader = StaticUploader.from_settings(path=BASE_DIR.joinpath("static"), endpoint="static")

# NOTE(review): this file imports from swarms.*, not core.* — confirm that a
# "core" directory still exists under BASE_DIR.
reload_dirs = [BASE_DIR / subdir for subdir in ("core", "api")]