groundingdino

Former-commit-id: b98bcb8b17
group-chat
Kye 1 year ago
parent eeeda3f88c
commit 044ea82e42

@ -28,11 +28,11 @@ from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
# Grounding DINO # Grounding DINO
import swarms.agents.models.groundingdino.datasets.transforms as T import swarms.agents.models.groundingdino.groundingdino.datasets.transforms as T
from swarms.agents.models.groundingdino.models import build_model from swarms.agents.models.groundingdino.groundingdino.models import build_model
from swarms.agents.models.groundingdino.util import box_ops from swarms.agents.models.groundingdino.groundingdino.util import box_ops
from swarms.agents.models.groundingdino.util.slconfig import SLConfig from swarms.agents.models.groundingdino.groundingdino.util.slconfig import SLConfig
from swarms.agents.models.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap from swarms.agents.models.groundingdino.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything # segment anything
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
@ -1025,7 +1025,7 @@ class Text2Box:
print(f"Initializing ObjectDetection to {device}") print(f"Initializing ObjectDetection to {device}")
self.device = device self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino") self.model_checkpoint_path = os.path.join("checkpoints","groundingdino"groundingdino.)
self.model_config_path = os.path.join("checkpoints","grounding_config.py") self.model_config_path = os.path.join("checkpoints","grounding_config.py")
self.download_parameters() self.download_parameters()
self.box_threshold = 0.3 self.box_threshold = 0.3
@ -1033,10 +1033,10 @@ class Text2Box:
self.grounding = (self.load_model()).to(self.device) self.grounding = (self.load_model()).to(self.device)
def download_parameters(self): def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth" url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_groundingdino.swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path): if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path) wget.download(url,out=self.model_checkpoint_path)
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py" config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/groundingdino.config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path): if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path) wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path): def load_image(self,image_path):

Loading…
Cancel
Save