[REFACTOR][QwenVLMultiModal]

pull/378/head^2
Kye 11 months ago
parent 586dd6bec2
commit c4f496a5bd

@ -0,0 +1,54 @@
import pkgutil
import inspect
class ModelRegistry:
"""
A registry for storing and querying models.
Attributes:
models (dict): A dictionary of model names and corresponding model classes.
Methods:
__init__(): Initializes the ModelRegistry object and retrieves all available models.
_get_all_models(): Retrieves all available models from the models package.
query(text): Queries the models based on the given text and returns a dictionary of matching models.
"""
def __init__(self):
self.models = self._get_all_models()
def _get_all_models(self):
"""
Retrieves all available models from the models package.
Returns:
dict: A dictionary of model names and corresponding model classes.
"""
models = {}
for importer, modname, ispkg in pkgutil.iter_modules(
models.__path__
):
module = importer.find_module(modname).load_module(
modname
)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj):
models[name] = obj
return models
def query(self, text):
"""
Queries the models based on the given text and returns a dictionary of matching models.
Args:
text (str): The text to search for in the model names.
Returns:
dict: A dictionary of matching model names and corresponding model classes.
"""
return {
name: model
for name, model in self.models.items()
if text in name
}

@ -13,10 +13,20 @@ class QwenVLMultiModal(BaseMultiModalModel):
QwenVLMultiModal is a class that represents a multi-modal model for Qwen chatbot. QwenVLMultiModal is a class that represents a multi-modal model for Qwen chatbot.
It inherits from the BaseMultiModalModel class. It inherits from the BaseMultiModalModel class.
Examples:
>>> model = QwenVLMultiModal()
>>> model.run("Hello, how are you?", "https://example.com/image.jpg")
Args:
model_name (str): The name of the model to be used.
device (str): The device to run the model on.
args (tuple): Additional positional arguments.
kwargs (dict): Additional keyword arguments.
quantize (bool): A flag to indicate whether to quantize the model.
return_bounding_boxes (bool): A flag to indicate whether to return bounding boxes for the image.
Examples:
>>> qwen = QwenVLMultiModal()
>>> response = qwen.run("Hello", "https://example.com/image.jpg")
>>> print(response)
""" """
model_name: str = "Qwen/Qwen-VL-Chat" model_name: str = "Qwen/Qwen-VL-Chat"
@ -24,6 +34,7 @@ class QwenVLMultiModal(BaseMultiModalModel):
args: tuple = field(default_factory=tuple) args: tuple = field(default_factory=tuple)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
quantize: bool = False quantize: bool = False
return_bounding_boxes: bool = False
def __post_init__(self): def __post_init__(self):
""" """
@ -60,19 +71,44 @@ class QwenVLMultiModal(BaseMultiModalModel):
and the image associated with the response (if any). and the image associated with the response (if any).
""" """
try: try:
query = self.tokenizer.from_list_format( if self.return_bounding_boxes:
[ query = self.tokenizer.from_list_format(
{"image": img, "text": text}, [
] {"image": img, "text": text},
) ]
)
inputs = self.tokenizer(query, return_tensors="pt")
inputs = inputs.to(self.model.device) inputs = self.tokenizer(query, return_tensors="pt")
pred = self.model.generate(**inputs) inputs = inputs.to(self.model.device)
response = self.tokenizer.decode( pred = self.model.generate(**inputs)
pred.cpu()[0], skip_special_tokens=False response = self.tokenizer.decode(
) pred.cpu()[0], skip_special_tokens=False
return response )
image_bb = self.tokenizer.draw_bbox_on_latest_picture(
response
)
if image_bb:
image_bb.save("output.jpg")
else:
print("No bounding boxes found in the image.")
return response, image_bb
else:
query = self.tokenizer.from_list_format(
[
{"image": img, "text": text},
]
)
inputs = self.tokenizer(query, return_tensors="pt")
inputs = inputs.to(self.model.device)
pred = self.model.generate(**inputs)
response = self.tokenizer.decode(
pred.cpu()[0], skip_special_tokens=False
)
return response
except Exception as error: except Exception as error:
print(f"[ERROR]: [QwenVLMultiModal]: {error}") print(f"[ERROR]: [QwenVLMultiModal]: {error}")

Loading…
Cancel
Save