|
|
@ -19,13 +19,12 @@ from transformers import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
|
|
|
|
from swarms.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
|
|
|
|
from swarms.tools.base import tool
|
|
|
|
from swarms.tools.tool import tool
|
|
|
|
from swarms.tools.main import BaseToolSet
|
|
|
|
|
|
|
|
from swarms.utils.logger import logger
|
|
|
|
from swarms.utils.logger import logger
|
|
|
|
from swarms.utils.main import BaseHandler, get_new_image_name
|
|
|
|
from swarms.utils.main import BaseHandler, get_new_image_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskFormer(BaseToolSet):
|
|
|
|
class MaskFormer:
|
|
|
|
def __init__(self, device):
|
|
|
|
def __init__(self, device):
|
|
|
|
print("Initializing MaskFormer to %s" % device)
|
|
|
|
print("Initializing MaskFormer to %s" % device)
|
|
|
|
self.device = device
|
|
|
|
self.device = device
|
|
|
@ -61,7 +60,7 @@ class MaskFormer(BaseToolSet):
|
|
|
|
return image_mask.resize(original_image.size)
|
|
|
|
return image_mask.resize(original_image.size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageEditing(BaseToolSet):
|
|
|
|
class ImageEditing:
|
|
|
|
def __init__(self, device):
|
|
|
|
def __init__(self, device):
|
|
|
|
print("Initializing ImageEditing to %s" % device)
|
|
|
|
print("Initializing ImageEditing to %s" % device)
|
|
|
|
self.device = device
|
|
|
|
self.device = device
|
|
|
@ -116,7 +115,7 @@ class ImageEditing(BaseToolSet):
|
|
|
|
return updated_image_path
|
|
|
|
return updated_image_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InstructPix2Pix(BaseToolSet):
|
|
|
|
class InstructPix2Pix:
|
|
|
|
def __init__(self, device):
|
|
|
|
def __init__(self, device):
|
|
|
|
print("Initializing InstructPix2Pix to %s" % device)
|
|
|
|
print("Initializing InstructPix2Pix to %s" % device)
|
|
|
|
self.device = device
|
|
|
|
self.device = device
|
|
|
@ -156,7 +155,7 @@ class InstructPix2Pix(BaseToolSet):
|
|
|
|
return updated_image_path
|
|
|
|
return updated_image_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Text2Image(BaseToolSet):
|
|
|
|
class Text2Image:
|
|
|
|
def __init__(self, device):
|
|
|
|
def __init__(self, device):
|
|
|
|
print("Initializing Text2Image to %s" % device)
|
|
|
|
print("Initializing Text2Image to %s" % device)
|
|
|
|
self.device = device
|
|
|
|
self.device = device
|
|
|
@ -190,7 +189,7 @@ class Text2Image(BaseToolSet):
|
|
|
|
return image_filename
|
|
|
|
return image_filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VisualQuestionAnswering(BaseToolSet):
|
|
|
|
class VisualQuestionAnswering:
|
|
|
|
def __init__(self, device):
|
|
|
|
def __init__(self, device):
|
|
|
|
print("Initializing VisualQuestionAnswering to %s" % device)
|
|
|
|
print("Initializing VisualQuestionAnswering to %s" % device)
|
|
|
|
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
|
|
|
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
|
|
|