parent
							
								
									403bed61fe
								
							
						
					
					
						commit
						aea843437e
					
				| @ -1,82 +1,82 @@ | |||||||
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | import requests | ||||||
|  | from PIL import Image | ||||||
|  | from transformers import AutoProcessor, LlavaForConditionalGeneration | ||||||
|  | from typing import Tuple, Union | ||||||
|  | from io import BytesIO | ||||||
|  | from swarms.models.base_multimodal_model import BaseMultiModalModel | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MultiModalLlava: | class LavaMultiModal(BaseMultiModalModel): | ||||||
|     """ |     """ | ||||||
|     LLava Model |     A class to handle multi-modal inputs (text and image) using the Llava model for conditional generation. | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         model_name (str): The name or path of the pre-trained model. | ||||||
|  |         max_length (int): The maximum length of the generated sequence. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         model_name_or_path: The model name or path to the model |         model_name (str): The name of the pre-trained model. | ||||||
|         revision: The revision of the model to use |         max_length (int): The maximum length of the generated sequence. | ||||||
|         device: The device to run the model on |         *args: Additional positional arguments. | ||||||
|         max_new_tokens: The maximum number of tokens to generate |         **kwargs: Additional keyword arguments. | ||||||
|         do_sample: Whether or not to use sampling |  | ||||||
|         temperature: The temperature of the sampling |  | ||||||
|         top_p: The top p value for sampling |  | ||||||
|         top_k: The top k value for sampling |  | ||||||
|         repetition_penalty: The repetition penalty for sampling |  | ||||||
|         device_map: The device map to use |  | ||||||
| 
 | 
 | ||||||
|     Methods: |     Examples: | ||||||
|         __call__: Call the model |     >>> model = LavaMultiModal() | ||||||
|         chat: Interactive chat in terminal |     >>> model.run("A cat", "https://example.com/cat.jpg") | ||||||
| 
 | 
 | ||||||
|     Example: |  | ||||||
|         >>> from swarms.models.llava import LlavaModel |  | ||||||
|         >>> model = LlavaModel(device="cpu") |  | ||||||
|         >>> model("Hello, I am a robot.") |  | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ", |         model_name: str = "llava-hf/llava-1.5-7b-hf", | ||||||
|         revision="main", |         max_length: int = 30, | ||||||
|         device="cuda", |         *args, | ||||||
|         max_new_tokens=512, |         **kwargs, | ||||||
|         do_sample=True, |     ) -> None: | ||||||
|         temperature=0.7, |         super().__init__(*args, **kwargs) | ||||||
|         top_p=0.95, |         self.model_name = model_name | ||||||
|         top_k=40, |         self.max_length = max_length | ||||||
|         repetition_penalty=1.1, |  | ||||||
|         device_map: str = "auto", |  | ||||||
|     ): |  | ||||||
|         self.device = device |  | ||||||
|         self.model = AutoModelForCausalLM.from_pretrained( |  | ||||||
|             model_name_or_path, |  | ||||||
|             device_map=device_map, |  | ||||||
|             trust_remote_code=False, |  | ||||||
|             revision=revision, |  | ||||||
|         ).to(self.device) |  | ||||||
| 
 | 
 | ||||||
|         self.tokenizer = AutoTokenizer.from_pretrained( |         self.model = LlavaForConditionalGeneration.from_pretrained( | ||||||
|             model_name_or_path, use_fast=True |             model_name, *args, **kwargs | ||||||
|         ) |  | ||||||
|         self.pipe = pipeline( |  | ||||||
|             "text-generation", |  | ||||||
|             model=self.model, |  | ||||||
|             tokenizer=self.tokenizer, |  | ||||||
|             max_new_tokens=max_new_tokens, |  | ||||||
|             do_sample=do_sample, |  | ||||||
|             temperature=temperature, |  | ||||||
|             top_p=top_p, |  | ||||||
|             top_k=top_k, |  | ||||||
|             repetition_penalty=repetition_penalty, |  | ||||||
|             device=0 if self.device == "cuda" else -1, |  | ||||||
|         ) |         ) | ||||||
|  |         self.processor = AutoProcessor.from_pretrained(model_name) | ||||||
| 
 | 
 | ||||||
|     def __call__(self, prompt): |     def run( | ||||||
|         """Call the model""" |         self, text: str, img: str, *args, **kwargs | ||||||
|         return self.pipe(prompt)[0]["generated_text"] |     ) -> Union[str, Tuple[None, str]]: | ||||||
|  |         """ | ||||||
|  |         Processes the input text and image, and generates a response. | ||||||
| 
 | 
 | ||||||
|     def chat(self): |         Args: | ||||||
|         """Interactive chat in terminal""" |             text (str): The input text for the model. | ||||||
|         print( |             img (str): The URL of the image to process. | ||||||
|             "Starting chat with LlavaModel. Type 'exit' to end the" |             **kwargs: Extra keyword arguments forwarded to generate(); the output length limit comes from self.max_length. | ||||||
|             " session." | 
 | ||||||
|         ) |         Returns: | ||||||
|         while True: |             Union[str, Tuple[None, str]]: The generated response string or a tuple (None, error message) in case of an error. | ||||||
|             user_input = input("You: ") |         """ | ||||||
|             if user_input.lower() == "exit": |         try: | ||||||
|                 break |             response = requests.get(img, stream=True) | ||||||
|             response = self(user_input) |             response.raise_for_status() | ||||||
|             print(f"Model: {response}") |             image = Image.open(BytesIO(response.content)) | ||||||
|  | 
 | ||||||
|  |             inputs = self.processor( | ||||||
|  |                 text=text, images=image, return_tensors="pt" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             # Generate | ||||||
|  |             generate_ids = self.model.generate( | ||||||
|  |                 **inputs, max_length=self.max_length, **kwargs | ||||||
|  |             ) | ||||||
|  |             return self.processor.batch_decode( | ||||||
|  |                 generate_ids, | ||||||
|  |                 skip_special_tokens=True, | ||||||
|  |                 clean_up_tokenization_spaces=False, | ||||||
|  |                 *args, | ||||||
|  |             )[0] | ||||||
|  | 
 | ||||||
|  |         except requests.RequestException as e: | ||||||
|  |             return None, f"Error fetching image: {str(e)}" | ||||||
|  |         except Exception as e: | ||||||
|  |             return None, f"Error during model processing: {str(e)}" | ||||||
|  | |||||||
| @ -0,0 +1,108 @@ | |||||||
|  | from dataclasses import dataclass, field | ||||||
|  | from typing import Optional, Tuple | ||||||
|  | 
 | ||||||
|  | from PIL import Image | ||||||
|  | from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||
|  | 
 | ||||||
|  | from swarms.models.base_multimodal_model import BaseMultiModalModel | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class QwenVLMultiModal(BaseMultiModalModel): | ||||||
|  |     """ | ||||||
|  |     QwenVLMultiModal is a class that represents a multi-modal model for the Qwen chatbot. | ||||||
|  |     It inherits from the BaseMultiModalModel class. | ||||||
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |     >>> model = QwenVLMultiModal() | ||||||
|  |     >>> model.run("Hello, how are you?", "https://example.com/image.jpg") | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     model_name: str = "Qwen/Qwen-VL-Chat" | ||||||
|  |     device: str = "cuda" | ||||||
|  |     args: tuple = field(default_factory=tuple) | ||||||
|  |     kwargs: dict = field(default_factory=dict) | ||||||
|  |     quantize: bool = False | ||||||
|  | 
 | ||||||
|  |     def __post_init__(self): | ||||||
|  |         """ | ||||||
|  |         Initializes the QwenVLMultiModal object. | ||||||
|  |         It initializes the tokenizer and the model for the Qwen chatbot. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if self.quantize: | ||||||
|  |             self.model_name = "Qwen/Qwen-VL-Chat-Int4" | ||||||
|  | 
 | ||||||
|  |         self.tokenizer = AutoTokenizer.from_pretrained( | ||||||
|  |             self.model_name, trust_remote_code=True | ||||||
|  |         ) | ||||||
|  |         self.model = AutoModelForCausalLM.from_pretrained( | ||||||
|  |             self.model_name, | ||||||
|  |             device_map=self.device, | ||||||
|  |             trust_remote_code=True, | ||||||
|  |         ).eval() | ||||||
|  | 
 | ||||||
|  |     def run( | ||||||
|  |         self, text: str, img: str, *args, **kwargs | ||||||
|  |     ) -> Tuple[Optional[str], Optional[Image.Image]]: | ||||||
|  |         """ | ||||||
|  |         Runs the Qwen chatbot model on the given text and image inputs. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             text (str): The input text for the chatbot. | ||||||
|  |             img (str): The input image for the chatbot. | ||||||
|  |             *args: Additional positional arguments. | ||||||
|  |             **kwargs: Additional keyword arguments. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Optional[str]: The decoded response generated by the chatbot. Note that, despite the annotated tuple type, | ||||||
|  |             no image is returned; if an error occurs it is printed and the method returns None. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             query = self.tokenizer.from_list_format( | ||||||
|  |                 [ | ||||||
|  |                     {"image": img, "text": text}, | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             inputs = self.tokenizer(query, return_tensors="pt") | ||||||
|  |             inputs = inputs.to(self.model.device) | ||||||
|  |             pred = self.model.generate(**inputs) | ||||||
|  |             response = self.tokenizer.decode( | ||||||
|  |                 pred.cpu()[0], skip_special_tokens=False | ||||||
|  |             ) | ||||||
|  |             return response | ||||||
|  |         except Exception as error: | ||||||
|  |             print(f"[ERROR]: [QwenVLMultiModal]: {error}") | ||||||
|  | 
 | ||||||
|  |     def chat( | ||||||
|  |         self, text: str, img: str, *args, **kwargs | ||||||
|  |     ) -> tuple[str, list]: | ||||||
|  |         """ | ||||||
|  |         Chat with the model using text and image inputs. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             text (str): The text input for the chat (currently unused: the query sent to the model is hard-coded). | ||||||
|  |             img (str): The image input for the chat. | ||||||
|  |             *args: Additional positional arguments. | ||||||
|  |             **kwargs: Additional keyword arguments. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             tuple[str, list]: A tuple containing the response and chat history. | ||||||
|  | 
 | ||||||
|  |         Raises: | ||||||
|  |             Exception: If an error occurs during the chat. | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             response, history = self.model.chat( | ||||||
|  |                 self.tokenizer, | ||||||
|  |                 query=f"<img>{img}</img>这是什么", | ||||||
|  |                 history=None, | ||||||
|  |             ) | ||||||
|  |             return response, history | ||||||
|  |         except Exception as e: | ||||||
|  |             raise Exception( | ||||||
|  |                 "An error occurred during the chat." | ||||||
|  |             ) from e | ||||||
| @ -0,0 +1,94 @@ | |||||||
|  | from io import BytesIO | ||||||
|  | 
 | ||||||
|  | import requests | ||||||
|  | import torch | ||||||
|  | from PIL import Image | ||||||
|  | from transformers import ( | ||||||
|  |     AutoProcessor, | ||||||
|  |     VipLlavaForConditionalGeneration, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | from swarms.models.base_multimodal_model import BaseMultiModalModel | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class VipLlavaMultiModal(BaseMultiModalModel): | ||||||
|  |     """ | ||||||
|  |     A multi-modal model for VIP-LLAVA. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         model_name (str): The name or path of the pre-trained model. | ||||||
|  |         max_new_tokens (int): The maximum number of new tokens to generate. | ||||||
|  |         device_map (str): The device mapping for the model. | ||||||
|  |         torch_dtype: The torch data type for the model. | ||||||
|  |         *args: Additional positional arguments. | ||||||
|  |         **kwargs: Additional keyword arguments. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_name: str = "llava-hf/vip-llava-7b-hf", | ||||||
|  |         max_new_tokens: int = 500, | ||||||
|  |         device_map: str = "auto", | ||||||
|  |         torch_dtype=torch.float16, | ||||||
|  |         *args, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.model_name = model_name | ||||||
|  |         self.max_new_tokens = max_new_tokens | ||||||
|  |         self.device_map = device_map | ||||||
|  |         self.torch_dtype = torch_dtype | ||||||
|  | 
 | ||||||
|  |         self.model = VipLlavaForConditionalGeneration.from_pretrained( | ||||||
|  |             model_name, | ||||||
|  |             device_map=device_map, | ||||||
|  |             torch_dtype=torch_dtype, | ||||||
|  |             *args, | ||||||
|  |             **kwargs, | ||||||
|  |         ) | ||||||
|  |         self.processor = AutoProcessor.from_pretrained( | ||||||
|  |             model_name, *args, **kwargs | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def run(self, text: str, img: str, *args, **kwargs): | ||||||
|  |         """ | ||||||
|  |         Run the VIP-LLAVA model. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             text (str): The input text. | ||||||
|  |             img (str): The URL of the input image. | ||||||
|  |             *args: Additional positional arguments. | ||||||
|  |             **kwargs: Additional keyword arguments. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             str: The generated output text. | ||||||
|  |             tuple: A tuple containing None and the error message if an error occurs. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             response = requests.get(img, stream=True) | ||||||
|  |             response.raise_for_status() | ||||||
|  |             image = Image.open(BytesIO(response.content)) | ||||||
|  | 
 | ||||||
|  |             inputs = self.processor( | ||||||
|  |                 text=text, | ||||||
|  |                 images=image, | ||||||
|  |                 return_tensors="pt", | ||||||
|  |                 *args, | ||||||
|  |                 **kwargs, | ||||||
|  |             ).to(0, self.torch_dtype) | ||||||
|  | 
 | ||||||
|  |             # Generate | ||||||
|  |             generate_ids = self.model.generate( | ||||||
|  |                 **inputs, max_new_tokens=self.max_new_tokens, **kwargs | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             return self.processor.decode( | ||||||
|  |                 generate_ids[0][len(inputs["input_ids"][0]) :], | ||||||
|  |                 skip_special_tokens=True, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         except requests.RequestException as error: | ||||||
|  |             return None, f"Error fetching image: {error}" | ||||||
|  | 
 | ||||||
|  |         except Exception as error: | ||||||
|  |             return None, f"Error during model inference: {error}" | ||||||
| @ -1,19 +1,22 @@ | |||||||
| import requests  | import requests | ||||||
| 
 | 
 | ||||||
| def download_weights_from_url(url: str, save_path: str = "models/weights.pth"): | 
 | ||||||
|  | def download_weights_from_url( | ||||||
|  |     url: str, save_path: str = "models/weights.pth" | ||||||
|  | ): | ||||||
|     """ |     """ | ||||||
|     Downloads model weights from the given URL and saves them to the specified path. |     Downloads model weights from the given URL and saves them to the specified path. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         url (str): The URL from which to download the model weights. |         url (str): The URL from which to download the model weights. | ||||||
|         save_path (str, optional): The path where the downloaded weights should be saved.  |         save_path (str, optional): The path where the downloaded weights should be saved. | ||||||
|             Defaults to "models/weights.pth". |             Defaults to "models/weights.pth". | ||||||
|     """ |     """ | ||||||
|     response = requests.get(url, stream=True) |     response = requests.get(url, stream=True) | ||||||
|     response.raise_for_status() |     response.raise_for_status() | ||||||
|      | 
 | ||||||
|     with open(save_path, "wb") as f: |     with open(save_path, "wb") as f: | ||||||
|         for chunk in response.iter_content(chunk_size=8192): |         for chunk in response.iter_content(chunk_size=8192): | ||||||
|             f.write(chunk) |             f.write(chunk) | ||||||
|          | 
 | ||||||
|     print(f"Model weights downloaded and saved to {save_path}") |     print(f"Model weights downloaded and saved to {save_path}") | ||||||
|  | |||||||
| @ -0,0 +1,60 @@ | |||||||
|  | from unittest.mock import Mock, patch | ||||||
|  | from swarms.models.qwen import QwenVLMultiModal | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_post_init(): | ||||||
|  |     with patch( | ||||||
|  |         "swarms.models.qwen.AutoTokenizer.from_pretrained" | ||||||
|  |     ) as mock_tokenizer, patch( | ||||||
|  |         "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" | ||||||
|  |     ) as mock_model: | ||||||
|  |         mock_tokenizer.return_value = Mock() | ||||||
|  |         mock_model.return_value = Mock() | ||||||
|  | 
 | ||||||
|  |         model = QwenVLMultiModal() | ||||||
|  |         mock_tokenizer.assert_called_once_with( | ||||||
|  |             model.model_name, trust_remote_code=True | ||||||
|  |         ) | ||||||
|  |         mock_model.assert_called_once_with( | ||||||
|  |             model.model_name, | ||||||
|  |             device_map=model.device, | ||||||
|  |             trust_remote_code=True, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_run(): | ||||||
|  |     with patch( | ||||||
|  |         "swarms.models.qwen.AutoTokenizer.from_list_format" | ||||||
|  |     ) as mock_format, patch( | ||||||
|  |         "swarms.models.qwen.AutoTokenizer.__call__" | ||||||
|  |     ) as mock_call, patch( | ||||||
|  |         "swarms.models.qwen.AutoModelForCausalLM.generate" | ||||||
|  |     ) as mock_generate, patch( | ||||||
|  |         "swarms.models.qwen.AutoTokenizer.decode" | ||||||
|  |     ) as mock_decode: | ||||||
|  |         mock_format.return_value = Mock() | ||||||
|  |         mock_call.return_value = Mock() | ||||||
|  |         mock_generate.return_value = Mock() | ||||||
|  |         mock_decode.return_value = "response" | ||||||
|  | 
 | ||||||
|  |         model = QwenVLMultiModal() | ||||||
|  |         response = model.run( | ||||||
|  |             "Hello, how are you?", "https://example.com/image.jpg" | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         assert response == "response" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_chat(): | ||||||
|  |     with patch( | ||||||
|  |         "swarms.models.qwen.AutoModelForCausalLM.chat" | ||||||
|  |     ) as mock_chat: | ||||||
|  |         mock_chat.return_value = ("response", ["history"]) | ||||||
|  | 
 | ||||||
|  |         model = QwenVLMultiModal() | ||||||
|  |         response, history = model.chat( | ||||||
|  |             "Hello, how are you?", "https://example.com/image.jpg" | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         assert response == "response" | ||||||
|  |         assert history == ["history"] | ||||||
					Loading…
					
					
				
		Reference in new issue