|
|
|
@ -28,11 +28,11 @@ from langchain.chains.conversation.memory import ConversationBufferMemory
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
|
|
|
|
|
|
# Grounding DINO
|
|
|
|
|
import swarms.agents.models.groundingdino.datasets.transforms as T
|
|
|
|
|
from swarms.agents.models.groundingdino.models import build_model
|
|
|
|
|
from swarms.agents.models.groundingdino.util import box_ops
|
|
|
|
|
from swarms.agents.models.groundingdino.util.slconfig import SLConfig
|
|
|
|
|
from swarms.agents.models.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
|
|
|
|
import swarms.agents.models.groundingdino.groundingdino.datasets.transforms as T
|
|
|
|
|
from swarms.agents.models.groundingdino.groundingdino.models import build_model
|
|
|
|
|
from swarms.agents.models.groundingdino.groundingdino.util import box_ops
|
|
|
|
|
from swarms.agents.models.groundingdino.groundingdino.util.slconfig import SLConfig
|
|
|
|
|
from swarms.agents.models.groundingdino.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
|
|
|
|
|
|
|
|
|
# segment anything
|
|
|
|
|
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
|
|
|
@ -1025,7 +1025,7 @@ class Text2Box:
|
|
|
|
|
print(f"Initializing ObjectDetection to {device}")
|
|
|
|
|
self.device = device
|
|
|
|
|
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
|
|
|
|
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino")
|
|
|
|
|
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino"groundingdino.)
|
|
|
|
|
self.model_config_path = os.path.join("checkpoints","grounding_config.py")
|
|
|
|
|
self.download_parameters()
|
|
|
|
|
self.box_threshold = 0.3
|
|
|
|
@ -1033,10 +1033,10 @@ class Text2Box:
|
|
|
|
|
self.grounding = (self.load_model()).to(self.device)
|
|
|
|
|
|
|
|
|
|
def download_parameters(self):
|
|
|
|
|
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
|
|
|
|
|
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_groundingdino.swint_ogc.pth"
|
|
|
|
|
if not os.path.exists(self.model_checkpoint_path):
|
|
|
|
|
wget.download(url,out=self.model_checkpoint_path)
|
|
|
|
|
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
|
|
|
|
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/groundingdino.config/GroundingDINO_SwinT_OGC.py"
|
|
|
|
|
if not os.path.exists(self.model_config_path):
|
|
|
|
|
wget.download(config_url,out=self.model_config_path)
|
|
|
|
|
def load_image(self,image_path):
|
|
|
|
|