diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index f798fab2..f6e46c31 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -2,4 +2,10 @@ from swarms.models.anthropic import Anthropic
 from swarms.models.petals import Petals
 from swarms.models.mistral import Mistral
 from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat
-# from swarms.models.fuyu import Fuyu
\ No newline at end of file
+
+
+
+# MultiModal Models
+from swarms.models.idefics import Idefics
+from swarms.models.kosmos_two import Kosmos
+from swarms.models.vilt import Vilt
\ No newline at end of file
diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py
new file mode 100644
index 00000000..fd790d37
--- /dev/null
+++ b/swarms/models/idefics.py
@@ -0,0 +1,223 @@
+import torch
+from transformers import AutoProcessor, IdeficsForVisionText2Text
+
+
+class Idefics:
+    """
+    A class for multimodal inference using pre-trained IDEFICS models from the Hugging Face Hub.
+
+    Attributes
+    ----------
+    device : str
+        The device to use for inference.
+    checkpoint : str, optional
+        The name of the pre-trained model checkpoint (default is "HuggingFaceM4/idefics-9b-instruct").
+    processor : transformers.PreTrainedProcessor
+        The pre-trained processor.
+    max_length : int
+        The maximum length of the generated text.
+    chat_history : list
+        The chat history.
+
+    Methods
+    -------
+    run(prompts, batched_mode=True)
+        Generates text based on the provided prompts.
+    chat(user_input)
+        Engages in a continuous bidirectional conversation based on the user input.
+    set_checkpoint(checkpoint)
+        Changes the model checkpoint.
+    set_device(device)
+        Changes the device used for inference.
+    set_max_length(max_length)
+        Changes the maximum length of the generated text.
+    clear_chat_history()
+        Clears the chat history.
+
+
+    # Usage
+    ```
+    from swarms.models.idefics import Idefics
+    mmi = Idefics()
+
+    user_input = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG"
+    response = mmi.chat(user_input)
+    print(response)
+
+    user_input = "User: And who is that? https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052"
+    response = mmi.chat(user_input)
+    print(response)
+
+    mmi.set_checkpoint("new_checkpoint")
+    mmi.set_device("cpu")
+    mmi.set_max_length(200)
+    mmi.clear_chat_history()
+    ```
+
+    """
+
+    def __init__(
+        self,
+        checkpoint="HuggingFaceM4/idefics-9b-instruct",
+        device=None,
+        torch_dtype=torch.bfloat16,
+        max_length=100,
+    ):
+        self.device = (
+            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+        )
+        self.model = IdeficsForVisionText2Text.from_pretrained(
+            checkpoint,
+            torch_dtype=torch_dtype,
+        ).to(self.device)
+
+        self.processor = AutoProcessor.from_pretrained(checkpoint)
+
+        self.max_length = max_length
+
+        self.chat_history = []
+
+    def run(self, prompts, batched_mode=True):
+        """
+        Generates text based on the provided prompts.
+
+        Parameters
+        ----------
+        prompts : list
+            A list of prompts. Each prompt is a list of text strings and images.
+        batched_mode : bool, optional
+            Whether to process the prompts in batched mode. If True, all prompts are
+            processed together. If False, only the first prompt is processed
+            (default is True).
+
+        Returns
+        -------
+        list
+            A list of generated text strings.
+ """ + inputs = ( + self.processor( + prompts, add_end_of_utterance_token=False, return_tensors="pt" + ).to(self.device) + if batched_mode + else self.processor(prompts[0], return_tensors="pt").to(self.device) + ) + + exit_condition = self.processor.tokenizer( + "", add_special_tokens=False + ).input_ids + + bad_words_ids = self.processor.tokenizer( + ["", "", add_special_tokens=False + ).input_ids + + bad_words_ids = self.processor.tokenizer( + ["", " x4 or y2 < y3 or y1 > y4) + + +class Kosmos: + """ + + Args: + + + + # Initialize Kosmos + kosmos = Kosmos() + + # Perform multimodal grounding + kosmos.multimodal_grounding("Find the red apple in the image.", "https://example.com/apple.jpg") + + # Perform referring expression comprehension + kosmos.referring_expression_comprehension("Show me the green bottle.", "https://example.com/bottle.jpg") + + # Generate referring expressions + kosmos.referring_expression_generation("It is on the table.", "https://example.com/table.jpg") + + # Perform grounded visual question answering + kosmos.grounded_vqa("What is the color of the car?", "https://example.com/car.jpg") + + # Generate grounded image caption + kosmos.grounded_image_captioning("https://example.com/beach.jpg") + """ + + def __init__( + self, + model_name="ydshieh/kosmos-2-patch14-224", + ): + self.model = AutoModelForVision2Seq.from_pretrained( + model_name, trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True + ) + + def get_image(self, url): + """Image""" + return Image.open(requests.get(url, stream=True).raw) + + def run(self, prompt, image): + """Run Kosmos""" + inputs = self.processor(text=prompt, images=image, return_tensors="pt") + generated_ids = self.model.generate( + pixel_values=inputs["pixel_values"], + input_ids=inputs["input_ids"][:, :-1], + attention_mask=inputs["attention_mask"][:, :-1], + img_features=None, + img_attn_mask=inputs["img_attn_mask"][:, :-1], + use_cache=True, + max_new_tokens=64, + ) + generated_texts = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True, + )[0] + processed_text, entities = self.processor.post_process_generation( + generated_texts + ) + + def __call__(self, prompt, image): + """Run call""" + inputs = self.processor(text=prompt, images=image, return_tensors="pt") + generated_ids = self.model.generate( + pixel_values=inputs["pixel_values"], + input_ids=inputs["input_ids"][:, :-1], + attention_mask=inputs["attention_mask"][:, :-1], + img_features=None, + img_attn_mask=inputs["img_attn_mask"][:, :-1], + use_cache=True, + max_new_tokens=64, + ) + generated_texts = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True, + )[0] + processed_text, entities = self.processor.post_process_generation( + generated_texts + ) + + # tasks + def multimodal_grounding(self, phrase, image_url): + prompt = f" {phrase} " + self.run(prompt, image_url) + + def referring_expression_comprehension(self, phrase, image_url): + prompt = f" {phrase} " + self.run(prompt, image_url) + + def referring_expression_generation(self, phrase, image_url): + prompt = " It is" + self.run(prompt, image_url) + + def grounded_vqa(self, question, image_url): + prompt = f" Question: {question} Answer:" + self.run(prompt, image_url) + + def grounded_image_captioning(self, image_url): + prompt = " An image of" + self.run(prompt, image_url) + + def grounded_image_captioning_detailed(self, image_url): + prompt = " Describe this image in detail" + self.run(prompt, image_url) + + def 
+    def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None):
+        """Draw bounding boxes and labels for the detected entities on an image.
+
+        Args:
+            image: a PIL image, an image path, or an image tensor
+            entities: a list of (entity_name, (start, end), bboxes) tuples with
+                bounding boxes normalized to [0, 1]
+        """
+        if isinstance(image, Image.Image):
+            image_h = image.height
+            image_w = image.width
+            image = np.array(image)[:, :, [2, 1, 0]]
+        elif isinstance(image, str):
+            if os.path.exists(image):
+                pil_img = Image.open(image).convert("RGB")
+                image = np.array(pil_img)[:, :, [2, 1, 0]]
+                image_h = pil_img.height
+                image_w = pil_img.width
+            else:
+                raise ValueError(f"invalid image path, {image}")
+        elif isinstance(image, torch.Tensor):
+            image_tensor = image.cpu()
+            reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[
+                :, None, None
+            ]
+            reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[
+                :, None, None
+            ]
+            image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+            pil_img = T.ToPILImage()(image_tensor)
+            image_h = pil_img.height
+            image_w = pil_img.width
+            image = np.array(pil_img)[:, :, [2, 1, 0]]
+        else:
+            raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+        if len(entities) == 0:
+            return image
+
+        new_image = image.copy()
+        previous_bboxes = []
+        # size of text
+        text_size = 1
+        # thickness of text
+        text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
+        box_line = 3
+        (c_width, text_height), _ = cv2.getTextSize(
+            "F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
+        )
+        base_height = int(text_height * 0.675)
+        text_offset_original = text_height - base_height
+        text_spaces = 3
+
+        for entity_name, (start, end), bboxes in entities:
+            for x1_norm, y1_norm, x2_norm, y2_norm in bboxes:
+                orig_x1, orig_y1, orig_x2, orig_y2 = (
+                    int(x1_norm * image_w),
+                    int(y1_norm * image_h),
+                    int(x2_norm * image_w),
+                    int(y2_norm * image_h),
+                )
+                # draw bbox
+                # random color
+                color = tuple(np.random.randint(0, 255, size=3).tolist())
+                new_image = cv2.rectangle(
+                    new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line
+                )
+
+                l_o, r_o = (
+                    box_line // 2 + box_line % 2,
+                    box_line // 2 + box_line % 2 + 1,
+                )
+
+                x1 = orig_x1 - l_o
+                y1 = orig_y1 - l_o
+
+                if y1 < text_height + text_offset_original + 2 * text_spaces:
+                    y1 = (
+                        orig_y1
+                        + r_o
+                        + text_height
+                        + text_offset_original
+                        + 2 * text_spaces
+                    )
+                    x1 = orig_x1 + r_o
+
+                # add text background
+                (text_width, text_height), _ = cv2.getTextSize(
+                    f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
+                )
+                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = (
+                    x1,
+                    y1 - (text_height + text_offset_original + 2 * text_spaces),
+                    x1 + text_width,
+                    y1,
+                )
+
+                for prev_bbox in previous_bboxes:
+                    while is_overlapping(
+                        (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox
+                    ):
+                        text_bg_y1 += (
+                            text_height + text_offset_original + 2 * text_spaces
+                        )
+                        text_bg_y2 += (
+                            text_height + text_offset_original + 2 * text_spaces
+                        )
+                        y1 += text_height + text_offset_original + 2 * text_spaces
+
+                        if text_bg_y2 >= image_h:
+                            text_bg_y1 = max(
+                                0,
+                                image_h
+                                - (
+                                    text_height + text_offset_original + 2 * text_spaces
+                                ),
+                            )
+                            text_bg_y2 = image_h
+                            y1 = image_h
+                            break
+
+                alpha = 0.5
+                for i in range(text_bg_y1, text_bg_y2):
+                    for j in range(text_bg_x1, text_bg_x2):
+                        if i < image_h and j < image_w:
+                            if j < text_bg_x1 + 1.35 * c_width:
+                                # original color
+                                bg_color = color
+                            else:
+                                # white
+                                bg_color = [255, 255, 255]
+                            new_image[i, j] = (
+                                alpha * new_image[i, j]
+                                + (1 - alpha) * np.array(bg_color)
+                            ).astype(np.uint8)
+
+                cv2.putText(
+                    new_image,
+                    f" {entity_name}",
+                    (x1, y1 - text_offset_original - 1 * text_spaces),
+                    cv2.FONT_HERSHEY_COMPLEX,
+                    text_size,
+                    (0, 0, 0),
+                    text_line,
+                    cv2.LINE_AA,
+                )
+                # previous_locations.append((x1, y1))
+                previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
+
+        pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
+        if save_path:
+            pil_image.save(save_path)
+        if show:
+            pil_image.show()
+
+        return new_image
+
+    def generate_boxees(self, prompt, image_url):
+        """Run Kosmos on an image URL and draw boxes around the detected entities."""
+        image = self.get_image(image_url)
+        processed_text, entities = self.run(prompt, image)
+        self.draw_entity_boxes_on_image(image, entities, show=True)
diff --git a/swarms/models/vilt.py b/swarms/models/vilt.py
new file mode 100644
index 00000000..e1677358
--- /dev/null
+++ b/swarms/models/vilt.py
@@ -0,0 +1,51 @@
+from transformers import ViltProcessor, ViltForQuestionAnswering
+import requests
+from PIL import Image
+
+
+class Vilt:
+    """
+    Vision-and-Language Transformer (ViLT) fine-tuned on VQAv2, introduced in the
+    paper "ViLT: Vision-and-Language Transformer Without Convolution or Region
+    Supervision" by Kim et al.
+
+    Checkpoint: https://huggingface.co/dandelin/vilt-b32-finetuned-vqa
+
+    Example:
+    >>> model = Vilt()
+    >>> output = model("What is this image", "http://images.cocodataset.org/val2017/000000039769.jpg")
+    """
+
+    def __init__(self):
+        self.processor = ViltProcessor.from_pretrained(
+            "dandelin/vilt-b32-finetuned-vqa"
+        )
+        self.model = ViltForQuestionAnswering.from_pretrained(
+            "dandelin/vilt-b32-finetuned-vqa"
+        )
+
+    def __call__(self, text: str, image_url: str):
+        """
+        Answer a question about an image.
+
+        Args:
+            text (str): the question to ask about the image
+            image_url (str): URL of the image to query
+
+        Returns:
+            str: the predicted answer
+        """
+        # Download the image
+        image = Image.open(requests.get(image_url, stream=True).raw)
+
+        encoding = self.processor(image, text, return_tensors="pt")
+
+        # Forward pass
+        outputs = self.model(**encoding)
+        logits = outputs.logits
+        idx = logits.argmax(-1).item()
+        answer = self.model.config.id2label[idx]
+        print("Predicted Answer:", answer)
+        return answer
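
With the updated `swarms/models/__init__.py`, all three new wrappers are importable directly from `swarms.models`. The snippet below is a minimal usage sketch, not part of the diff: it assumes the Hugging Face checkpoints referenced above can be downloaded, that enough GPU memory is available for the 9B IDEFICS checkpoint (otherwise pass `device="cpu"`), and it reuses the COCO image URL from the ViLT docstring as a stand-in example.

```python
# Minimal sketch of the new multimodal wrappers re-exported by swarms.models.
# Assumes checkpoints can be fetched from the Hugging Face Hub and that the
# example image URL (borrowed from the ViLT docstring) is reachable.
from swarms.models import Idefics, Kosmos, Vilt

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# IDEFICS: chat-style multimodal inference with conversation history
idefics = Idefics()
print(idefics.chat(f"User: What is in this image? {image_url}"))

# Kosmos-2: grounded visual question answering (returns processed text + entities)
kosmos = Kosmos()
print(kosmos.grounded_vqa("What animals are shown?", image_url))

# ViLT: single-turn visual question answering
vilt = Vilt()
print(vilt("How many cats are there?", image_url))
```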