diff --git a/mm_agent_example.py b/mm_agent_example.py index ca1e6051..66e050c6 100644 --- a/mm_agent_example.py +++ b/mm_agent_example.py @@ -1,7 +1,7 @@ from swarms.agents import MultiModalAgent load_dict = { - "ImageCaptioning": "default_device" + "ImageCaptioning": "cuda:0" } node = MultiModalAgent(load_dict) diff --git a/swarms/agents/multi_modal_visual_agent.py b/swarms/agents/multi_modal_visual_agent.py index 4f175e22..2ea3e5c3 100644 --- a/swarms/agents/multi_modal_visual_agent.py +++ b/swarms/agents/multi_modal_visual_agent.py @@ -1478,7 +1478,7 @@ class MultiModalVisualAgent: self.models = {} for class_name, device in load_dict.items(): - self.models[class_name] = globals()[class_name]#(device=device) + self.models[class_name] = globals()[class_name](device=device) for class_name, module in globals().items(): if getattr(module, 'template_model', False):