pull/58/head
Kye 1 year ago
parent b98bcb8b17
commit 44ba173142

@ -28,11 +28,11 @@ from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI
# Grounding DINO
import swarms.agents.models.groundingdino.groundingdino.datasets.transforms as T
from swarms.agents.models.groundingdino.groundingdino.models import build_model
from swarms.agents.models.groundingdino.groundingdino.util import box_ops
from swarms.agents.models.groundingdino.groundingdino.util.slconfig import SLConfig
from swarms.agents.models.groundingdino.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
import swarms.agents.models.groundingdino.datasets.transforms as T
from swarms.agents.models.groundingdino.models import build_model
from swarms.agents.models.groundingdino.util import box_ops
from swarms.agents.models.groundingdino.util.slconfig import SLConfig
from swarms.agents.models.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
@ -1025,7 +1025,7 @@ class Text2Box:
print(f"Initializing ObjectDetection to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino"groundingdino.)
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino")
self.model_config_path = os.path.join("checkpoints","grounding_config.py")
self.download_parameters()
self.box_threshold = 0.3
@ -1033,10 +1033,10 @@ class Text2Box:
self.grounding = (self.load_model()).to(self.device)
def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_groundingdino.swint_ogc.pth"
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/groundingdino.config/GroundingDINO_SwinT_OGC.py"
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path):

Loading…
Cancel
Save