import importlib
import inspect
import pkgutil


class ModelRegistry:
    """
    A registry for storing and querying models.

    The registry imports every direct submodule of a package and records
    each class found in those modules under its attribute name.

    Attributes:
        models (dict): A dictionary of model names and corresponding model classes.

    Methods:
        __init__(package): Initializes the ModelRegistry object and retrieves all available models.
        _get_all_models(package): Retrieves all available models from the given package.
        query(text): Queries the models based on the given text and returns a dictionary of matching models.
    """

    def __init__(self, package=None):
        """
        Build the registry by scanning a package for classes.

        Args:
            package (str | module | None): Dotted name of the package to
                scan, or the already-imported package object. When None,
                the package that contains this module is scanned.
        """
        self.models = self._get_all_models(package)

    def _get_all_models(self, package=None):
        """
        Retrieves all available models from the models package.

        Every direct submodule of the package is imported, and every class
        visible in a submodule's namespace is registered. That includes
        classes the submodule merely imported; when two modules expose a
        class under the same name, the one scanned last wins.

        Args:
            package (str | module | None): Dotted name of the package to
                scan, or the package object itself. Defaults to the package
                that contains this module.

        Returns:
            dict: A dictionary of model names and corresponding model classes.

        Raises:
            ValueError: If no package is given and this module is not part
                of a package, so there is nothing to scan.
            ImportError: If the package or one of its submodules cannot be
                imported.
        """
        if package is None:
            # Default to the package this file lives in.
            package = __package__
        if not package:
            raise ValueError(
                "ModelRegistry has no package to scan; pass `package` explicitly."
            )
        if isinstance(package, str):
            package = importlib.import_module(package)

        # Use a name distinct from the package: the previous code called the
        # result dict `models` and then read `models.__path__`, which raised
        # AttributeError on every construction.
        registry = {}
        for module_info in pkgutil.iter_modules(package.__path__):
            # Import under the fully-qualified name so relative imports
            # inside the submodule resolve and sys.modules holds one copy.
            module = importlib.import_module(
                f"{package.__name__}.{module_info.name}"
            )
            for name, obj in inspect.getmembers(module, inspect.isclass):
                registry[name] = obj
        return registry

    def query(self, text):
        """
        Queries the models based on the given text and returns a dictionary of matching models.

        Matching is a case-sensitive substring test against the registered
        name, so an empty string matches every model.

        Args:
            text (str): The text to search for in the model names.

        Returns:
            dict: A dictionary of matching model names and corresponding model classes.
        """
        return {
            name: model
            for name, model in self.models.items()
            if text in name
        }
- Examples: - >>> model = QwenVLMultiModal() - >>> model.run("Hello, how are you?", "https://example.com/image.jpg") + Args: + model_name (str): The name of the model to be used. + device (str): The device to run the model on. + args (tuple): Additional positional arguments. + kwargs (dict): Additional keyword arguments. + quantize (bool): A flag to indicate whether to quantize the model. + return_bounding_boxes (bool): A flag to indicate whether to return bounding boxes for the image. + + + Examples: + >>> qwen = QwenVLMultiModal() + >>> response = qwen.run("Hello", "https://example.com/image.jpg") + >>> print(response) """ model_name: str = "Qwen/Qwen-VL-Chat" @@ -24,6 +34,7 @@ class QwenVLMultiModal(BaseMultiModalModel): args: tuple = field(default_factory=tuple) kwargs: dict = field(default_factory=dict) quantize: bool = False + return_bounding_boxes: bool = False def __post_init__(self): """ @@ -60,19 +71,44 @@ class QwenVLMultiModal(BaseMultiModalModel): and the image associated with the response (if any). 
""" try: - query = self.tokenizer.from_list_format( - [ - {"image": img, "text": text}, - ] - ) - - inputs = self.tokenizer(query, return_tensors="pt") - inputs = inputs.to(self.model.device) - pred = self.model.generate(**inputs) - response = self.tokenizer.decode( - pred.cpu()[0], skip_special_tokens=False - ) - return response + if self.return_bounding_boxes: + query = self.tokenizer.from_list_format( + [ + {"image": img, "text": text}, + ] + ) + + inputs = self.tokenizer(query, return_tensors="pt") + inputs = inputs.to(self.model.device) + pred = self.model.generate(**inputs) + response = self.tokenizer.decode( + pred.cpu()[0], skip_special_tokens=False + ) + + image_bb = self.tokenizer.draw_bbox_on_latest_picture( + response + ) + + if image_bb: + image_bb.save("output.jpg") + else: + print("No bounding boxes found in the image.") + + return response, image_bb + else: + query = self.tokenizer.from_list_format( + [ + {"image": img, "text": text}, + ] + ) + + inputs = self.tokenizer(query, return_tensors="pt") + inputs = inputs.to(self.model.device) + pred = self.model.generate(**inputs) + response = self.tokenizer.decode( + pred.cpu()[0], skip_special_tokens=False + ) + return response except Exception as error: print(f"[ERROR]: [QwenVLMultiModal]: {error}")