diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index 364d1d7f..635124a6 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -44,7 +44,9 @@ from swarms.models.timm import TimmModel  # noqa: E402
 from swarms.models.ultralytics_model import (
     UltralyticsModel,
 )  # noqa: E402
-
+from swarms.models.vip_llava import VipLlavaMultiModal  # noqa: E402
+from swarms.models.llava import LavaMultiModal  # noqa: E402
+from swarms.models.qwen import QwenVLMultiModal  # noqa: E402
 
 # from swarms.models.dalle3 import Dalle3
 # from swarms.models.distilled_whisperx import DistilWhisperModel  # noqa: E402
@@ -105,4 +107,7 @@ __all__ = [
     "TogetherLLM",
     "TimmModel",
     "UltralyticsModel",
+    "VipLlavaMultiModal",
+    "LavaMultiModal",
+    "QwenVLMultiModal",
 ]
diff --git a/swarms/models/llava.py b/swarms/models/llava.py
index 605904c3..bcc1b09f 100644
--- a/swarms/models/llava.py
+++ b/swarms/models/llava.py
@@ -1,82 +1,82 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import requests
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from typing import Tuple, Union
+from io import BytesIO
+from swarms.models.base_multimodal_model import BaseMultiModalModel
 
 
-class MultiModalLlava:
+class LavaMultiModal(BaseMultiModalModel):
     """
-    LLava Model
+    A class to handle multi-modal inputs (text and image) using the Llava model for conditional generation.
+
+    Attributes:
+        model_name (str): The name or path of the pre-trained model.
+        max_length (int): The maximum length of the generated sequence.
 
     Args:
-        model_name_or_path: The model name or path to the model
-        revision: The revision of the model to use
-        device: The device to run the model on
-        max_new_tokens: The maximum number of tokens to generate
-        do_sample: Whether or not to use sampling
-        temperature: The temperature of the sampling
-        top_p: The top p value for sampling
-        top_k: The top k value for sampling
-        repetition_penalty: The repetition penalty for sampling
-        device_map: The device map to use
+        model_name (str): The name of the pre-trained model.
+        max_length (int): The maximum length of the generated sequence.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
 
-    Methods:
-        __call__: Call the model
-        chat: Interactive chat in terminal
+    Examples:
+    >>> model = LavaMultiModal()
+    >>> model.run("A cat", "https://example.com/cat.jpg")
 
-    Example:
-        >>> from swarms.models.llava import LlavaModel
-        >>> model = LlavaModel(device="cpu")
-        >>> model("Hello, I am a robot.")
     """
 
     def __init__(
         self,
-        model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ",
-        revision="main",
-        device="cuda",
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.95,
-        top_k=40,
-        repetition_penalty=1.1,
-        device_map: str = "auto",
-    ):
-        self.device = device
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path,
-            device_map=device_map,
-            trust_remote_code=False,
-            revision=revision,
-        ).to(self.device)
+        model_name: str = "llava-hf/llava-1.5-7b-hf",
+        max_length: int = 30,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.model_name = model_name
+        self.max_length = max_length
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path, use_fast=True
-        )
-        self.pipe = pipeline(
-            "text-generation",
-            model=self.model,
-            tokenizer=self.tokenizer,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            device=0 if self.device == "cuda" else -1,
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            model_name, *args, **kwargs
         )
+        self.processor = AutoProcessor.from_pretrained(model_name)
 
-    def __call__(self, prompt):
-        """Call the model"""
-        return self.pipe(prompt)[0]["generated_text"]
+    def run(
+        self, text: str, img: str, *args, **kwargs
+    ) -> Union[str, Tuple[None, str]]:
+        """
+        Processes the input text and image, and generates a response.
 
-    def chat(self):
-        """Interactive chat in terminal"""
-        print(
-            "Starting chat with LlavaModel. Type 'exit' to end the"
-            " session."
-        )
-        while True:
-            user_input = input("You: ")
-            if user_input.lower() == "exit":
-                break
-            response = self(user_input)
-            print(f"Model: {response}")
+        Args:
+            text (str): The input text for the model.
+            img (str): The URL of the image to process.
+
+        Returns:
+            Union[str, Tuple[None, str]]: The generated response string or a tuple (None, error message) in case of an error.
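+
+        Note:
+            Errors are not raised to the caller; request failures and model
+            errors are caught and returned as a (None, error message) tuple.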
+        """
+        try:
+            response = requests.get(img, stream=True)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content))
+
+            inputs = self.processor(
+                text=text, images=image, return_tensors="pt"
+            )
+
+            # Generate
+            generate_ids = self.model.generate(
+                **inputs, max_length=self.max_length, **kwargs
+            )
+            return self.processor.batch_decode(
+                generate_ids,
+                *args,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False,
+            )[0]
+
+        except requests.RequestException as e:
+            return None, f"Error fetching image: {str(e)}"
+        except Exception as e:
+            return None, f"Error during model processing: {str(e)}"
diff --git a/swarms/models/medical_sam.py b/swarms/models/medical_sam.py
index 01e77c04..8d096ba5 100644
--- a/swarms/models/medical_sam.py
+++ b/swarms/models/medical_sam.py
@@ -10,6 +10,10 @@ from skimage import transform
 from torch import Tensor
 
 
+def sam_model_registry():
+    pass
+
+
 @dataclass
 class MedicalSAM:
     """
diff --git a/swarms/models/odin.py b/swarms/models/odin.py
index a6228159..27cb1710 100644
--- a/swarms/models/odin.py
+++ b/swarms/models/odin.py
@@ -3,7 +3,10 @@ import supervision as sv
 from ultralytics import YOLO
 from tqdm import tqdm
 from swarms.models.base_llm import AbstractLLM
-from swarms.utils.download_weights_from_url import download_weights_from_url
+from swarms.utils.download_weights_from_url import (
+    download_weights_from_url,
+)
+
 
 class Odin(AbstractLLM):
     """
@@ -13,7 +16,7 @@ class Odin(AbstractLLM):
         source_weights_path (str): The file path to the YOLO model weights.
         confidence_threshold (float): The confidence threshold for object detection.
         iou_threshold (float): The intersection over union (IOU) threshold for object detection.
-
+
     Example:
     >>> odin = Odin(
     ...     source_weights_path="yolo.weights",
@@ -21,8 +24,8 @@ class Odin(AbstractLLM):
     ...     iou_threshold=0.7,
     ... )
     >>> odin.run(video="input.mp4")
-
-
+
+
     """
 
     def __init__(
@@ -35,12 +38,12 @@ class Odin(AbstractLLM):
         self.source_weights_path = source_weights_path
         self.confidence_threshold = confidence_threshold
         self.iou_threshold = iou_threshold
-
+
         if not os.path.exists(self.source_weights_path):
             download_weights_from_url(
-                url=source_weights_path, save_path=self.source_weights_path
+                url=source_weights_path,
+                save_path=self.source_weights_path,
             )
-
 
     def run(self, video: str, *args, **kwargs):
         """
@@ -61,9 +64,7 @@ class Odin(AbstractLLM):
         frame_generator = sv.get_video_frames_generator(
             source_path=self.source_video
         )
-        video_info = sv.VideoInfo.from_video(
-            video=video
-        )
+        video_info = sv.VideoInfo.from_video(video=video)
 
         with sv.VideoSink(
             target_path=self.target_video, video_info=video_info
diff --git a/swarms/models/qwen.py b/swarms/models/qwen.py
new file mode 100644
index 00000000..1533b117
--- /dev/null
+++ b/swarms/models/qwen.py
@@ -0,0 +1,108 @@
+from dataclasses import dataclass, field
+from typing import Optional, Tuple
+
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+@dataclass
+class QwenVLMultiModal(BaseMultiModalModel):
+    """
+    QwenVLMultiModal is a class that represents a multi-modal model for Qwen chatbot.
+    It inherits from the BaseMultiModalModel class.
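+
+    Attributes:
+        model_name (str): The name or path of the pre-trained Qwen-VL model.
+        device (str): The device to load the model onto, e.g. "cuda" or "cpu".
+        args (tuple): Extra positional arguments stored on the instance.
+        kwargs (dict): Extra keyword arguments stored on the instance.
+        quantize (bool): If True, the Int4-quantized "Qwen/Qwen-VL-Chat-Int4"
+            checkpoint is loaded instead of the full-precision model.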
+
+    Examples:
+    >>> model = QwenVLMultiModal()
+    >>> model.run("Hello, how are you?", "https://example.com/image.jpg")
+
+    """
+
+    model_name: str = "Qwen/Qwen-VL-Chat"
+    device: str = "cuda"
+    args: tuple = field(default_factory=tuple)
+    kwargs: dict = field(default_factory=dict)
+    quantize: bool = False
+
+    def __post_init__(self):
+        """
+        Initializes the QwenVLMultiModal object.
+        It initializes the tokenizer and the model for the Qwen chatbot.
+        """
+
+        if self.quantize:
+            self.model_name = "Qwen/Qwen-VL-Chat-Int4"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name, trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            trust_remote_code=True,
+        ).eval()
+
+    def run(
+        self, text: str, img: str, *args, **kwargs
+    ) -> Tuple[Optional[str], Optional[Image.Image]]:
+        """
+        Runs the Qwen chatbot model on the given text and image inputs.
+
+        Args:
+            text (str): The input text for the chatbot.
+            img (str): The input image for the chatbot.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Tuple[Optional[str], Optional[Image.Image]]: A tuple containing the response generated by the chatbot
+            and the image associated with the response (if any).
+        """
+        try:
+            query = self.tokenizer.from_list_format(
+                [
+                    {"image": img, "text": text},
+                ]
+            )
+
+            inputs = self.tokenizer(query, return_tensors="pt")
+            inputs = inputs.to(self.model.device)
+            pred = self.model.generate(**inputs)
+            response = self.tokenizer.decode(
+                pred.cpu()[0], skip_special_tokens=False
+            )
+            return response
+        except Exception as error:
+            print(f"[ERROR]: [QwenVLMultiModal]: {error}")
+
+    def chat(
+        self, text: str, img: str, *args, **kwargs
+    ) -> tuple[str, list]:
+        """
+        Chat with the model using text and image inputs.
+
+        Args:
+            text (str): The text input for the chat.
+            img (str): The image input for the chat.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            tuple[str, list]: A tuple containing the response and chat history.
+
+        Raises:
+            Exception: If an error occurs during the chat.
+
+        """
+        try:
+            query = self.tokenizer.from_list_format(
+                [
+                    {"image": img},
+                    {"text": text},
+                ]
+            )
+            response, history = self.model.chat(
+                self.tokenizer, query=query, history=None
+            )
+            return response, history
+        except Exception as e:
+            raise Exception(
+                "An error occurred during the chat."
+            ) from e
diff --git a/swarms/models/vip_llava.py b/swarms/models/vip_llava.py
new file mode 100644
index 00000000..31726275
--- /dev/null
+++ b/swarms/models/vip_llava.py
@@ -0,0 +1,94 @@
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+from transformers import (
+    AutoProcessor,
+    VipLlavaForConditionalGeneration,
+)
+
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+class VipLlavaMultiModal(BaseMultiModalModel):
+    """
+    A multi-modal model for VIP-LLAVA.
+
+    Args:
+        model_name (str): The name or path of the pre-trained model.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+        device_map (str): The device mapping for the model.
+        torch_dtype: The torch data type for the model.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
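+
+    Examples:
+    >>> model = VipLlavaMultiModal()
+    >>> model.run("What is shown in this image?", "https://example.com/image.jpg")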
+    """
+
+    def __init__(
+        self,
+        model_name: str = "llava-hf/vip-llava-7b-hf",
+        max_new_tokens: int = 500,
+        device_map: str = "auto",
+        torch_dtype=torch.float16,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.model_name = model_name
+        self.max_new_tokens = max_new_tokens
+        self.device_map = device_map
+        self.torch_dtype = torch_dtype
+
+        self.model = VipLlavaForConditionalGeneration.from_pretrained(
+            model_name,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            *args,
+            **kwargs,
+        )
+        self.processor = AutoProcessor.from_pretrained(
+            model_name, *args, **kwargs
+        )
+
+    def run(self, text: str, img: str, *args, **kwargs):
+        """
+        Run the VIP-LLAVA model.
+
+        Args:
+            text (str): The input text.
+            img (str): The URL of the input image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The generated output text.
+            tuple: A tuple containing None and the error message if an error occurs.
+        """
+        try:
+            response = requests.get(img, stream=True)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content))
+
+            inputs = self.processor(
+                text=text,
+                images=image,
+                return_tensors="pt",
+                *args,
+                **kwargs,
+            ).to(0, self.torch_dtype)
+
+            # Generate
+            generate_ids = self.model.generate(
+                **inputs, max_new_tokens=self.max_new_tokens, **kwargs
+            )
+
+            return self.processor.decode(
+                generate_ids[0][len(inputs["input_ids"][0]) :],
+                skip_special_tokens=True,
+            )
+
+        except requests.RequestException as error:
+            return None, f"Error fetching image: {error}"
+
+        except Exception as error:
+            return None, f"Error during model inference: {error}"
diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py
index e265a1c8..de8ffb57 100644
--- a/swarms/utils/__init__.py
+++ b/swarms/utils/__init__.py
@@ -19,7 +19,9 @@ from swarms.utils.data_to_text import (
     data_to_text,
 )
 from swarms.utils.try_except_wrapper import try_except_wrapper
-
+from swarms.utils.download_weights_from_url import (
+    download_weights_from_url,
+)
 
 __all__ = [
     "SubprocessCodeInterpreter",
@@ -39,4 +41,5 @@ __all__ = [
     "txt_to_text",
     "data_to_text",
     "try_except_wrapper",
+    "download_weights_from_url",
 ]
diff --git a/swarms/utils/download_weights_from_url.py b/swarms/utils/download_weights_from_url.py
index bc93d699..b5fa1633 100644
--- a/swarms/utils/download_weights_from_url.py
+++ b/swarms/utils/download_weights_from_url.py
@@ -1,19 +1,22 @@
-import requests
+import requests
 
-def download_weights_from_url(url: str, save_path: str = "models/weights.pth"):
+
+def download_weights_from_url(
+    url: str, save_path: str = "models/weights.pth"
+):
     """
     Downloads model weights from the given URL and saves them to the specified path.
 
     Args:
         url (str): The URL from which to download the model weights.
-        save_path (str, optional): The path where the downloaded weights should be saved.
+        save_path (str, optional): The path where the downloaded weights should be saved. Defaults to "models/weights.pth".
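+
+    Raises:
+        requests.exceptions.HTTPError: If the server responds with an
+            unsuccessful status code for the download request.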
""" response = requests.get(url, stream=True) response.raise_for_status() - + with open(save_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - - print(f"Model weights downloaded and saved to {save_path}") \ No newline at end of file + + print(f"Model weights downloaded and saved to {save_path}") diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py new file mode 100644 index 00000000..28178fc0 --- /dev/null +++ b/tests/models/test_qwen.py @@ -0,0 +1,60 @@ +from unittest.mock import Mock, patch +from swarms.models.qwen import QwenVLMultiModal + + +def test_post_init(): + with patch( + "swarms.models.qwen.AutoTokenizer.from_pretrained" + ) as mock_tokenizer, patch( + "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" + ) as mock_model: + mock_tokenizer.return_value = Mock() + mock_model.return_value = Mock() + + model = QwenVLMultiModal() + mock_tokenizer.assert_called_once_with( + model.model_name, trust_remote_code=True + ) + mock_model.assert_called_once_with( + model.model_name, + device_map=model.device, + trust_remote_code=True, + ) + + +def test_run(): + with patch( + "swarms.models.qwen.AutoTokenizer.from_list_format" + ) as mock_format, patch( + "swarms.models.qwen.AutoTokenizer.__call__" + ) as mock_call, patch( + "swarms.models.qwen.AutoModelForCausalLM.generate" + ) as mock_generate, patch( + "swarms.models.qwen.AutoTokenizer.decode" + ) as mock_decode: + mock_format.return_value = Mock() + mock_call.return_value = Mock() + mock_generate.return_value = Mock() + mock_decode.return_value = "response" + + model = QwenVLMultiModal() + response = model.run( + "Hello, how are you?", "https://example.com/image.jpg" + ) + + assert response == "response" + + +def test_chat(): + with patch( + "swarms.models.qwen.AutoModelForCausalLM.chat" + ) as mock_chat: + mock_chat.return_value = ("response", ["history"]) + + model = QwenVLMultiModal() + response, history = model.chat( + "Hello, how are you?", "https://example.com/image.jpg" + ) + + assert response == "response" + assert history == ["history"]