parent
403bed61fe
commit
aea843437e
@ -1,82 +1,82 @@
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||||
|
from typing import Tuple, Union
|
||||||
|
from io import BytesIO
|
||||||
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
||||||
|
|
||||||
|
|
||||||
class MultiModalLlava:
    """
    Text-generation wrapper around a causal LM checkpoint and a
    ``transformers`` text-generation pipeline.

    Args:
        model_name_or_path: The model name or path to the model
        revision: The revision of the model to use
        device: The device to run the model on
        max_new_tokens: The maximum number of tokens to generate
        do_sample: Whether or not to use sampling
        temperature: The temperature of the sampling
        top_p: The top p value for sampling
        top_k: The top k value for sampling
        repetition_penalty: The repetition penalty for sampling
        device_map: The device map to use

    Methods:
        __call__: Call the model
        chat: Interactive chat in terminal

    Example:
        >>> model = MultiModalLlava(device="cpu")
        >>> model("Hello, I am a robot.")
    """

    def __init__(
        self,
        model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ",
        revision="main",
        device="cuda",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
        device_map: str = "auto",
    ):
        self.device = device
        # NOTE(review): the model is loaded with ``device_map`` and then
        # moved again with ``.to(device)``; confirm both are intended.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
            trust_remote_code=False,
            revision=revision,
        ).to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, use_fast=True
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            # The pipeline takes a device index: 0 for "cuda", -1 otherwise.
            device=0 if self.device == "cuda" else -1,
        )

    def __call__(self, prompt):
        """Call the model and return the first result's generated text."""
        return self.pipe(prompt)[0]["generated_text"]

    def chat(self):
        """Interactive chat in terminal; typing 'exit' ends the loop."""
        print(
            "Starting chat with LlavaModel. Type 'exit' to end the"
            " session."
        )
        while True:
            user_input = input("You: ")
            if user_input.lower() == "exit":
                break
            response = self(user_input)
            print(f"Model: {response}")


class LavaMultiModal(BaseMultiModalModel):
    """
    A class to handle multi-modal inputs (text and image) using the Llava model for conditional generation.

    Attributes:
        model_name (str): The name or path of the pre-trained model.
        max_length (int): The maximum length of the generated sequence.
        image_fetch_timeout (float): Seconds to wait for the image download.

    Args:
        model_name (str): The name of the pre-trained model.
        max_length (int): The maximum length of the generated sequence.
        *args: Additional positional arguments, forwarded to both the base
            class initializer and ``from_pretrained``.
        **kwargs: Additional keyword arguments, forwarded to both the base
            class initializer and ``from_pretrained``.

    Examples:
        >>> model = LavaMultiModal()
        >>> model.run("A cat", "https://example.com/cat.jpg")
    """

    # requests waits indefinitely when no timeout is given, so a stalled
    # image host would otherwise hang run() forever.
    image_fetch_timeout = 30

    def __init__(
        self,
        model_name: str = "llava-hf/llava-1.5-7b-hf",
        max_length: int = 30,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.max_length = max_length

        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_name, *args, **kwargs
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def run(
        self, text: str, img: str, *args, **kwargs
    ) -> Union[str, Tuple[None, str]]:
        """
        Processes the input text and image, and generates a response.

        Args:
            text (str): The input text for the model.
            img (str): The URL of the image to process.
            *args: Extra positional arguments forwarded to
                ``processor.batch_decode``.
            **kwargs: Extra keyword arguments forwarded to
                ``model.generate``.

        Returns:
            Union[str, Tuple[None, str]]: The generated response string or a tuple (None, error message) in case of an error.
        """
        try:
            response = requests.get(
                img, stream=True, timeout=self.image_fetch_timeout
            )
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))

            inputs = self.processor(
                text=text, images=image, return_tensors="pt"
            )
            # Keep the inputs on the same device as the model so generation
            # also works when the weights were placed on an accelerator.
            inputs = inputs.to(self.model.device)

            # Generate
            generate_ids = self.model.generate(
                **inputs, max_length=self.max_length, **kwargs
            )
            return self.processor.batch_decode(
                generate_ids,
                *args,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

        except requests.RequestException as e:
            # Covers connection errors, HTTP error statuses and timeouts.
            return None, f"Error fetching image: {str(e)}"
        except Exception as e:
            return None, f"Error during model processing: {str(e)}"
|
@ -0,0 +1,108 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QwenVLMultiModal(BaseMultiModalModel):
    """
    QwenVLMultiModal is a class that represents a multi-modal model for Qwen chatbot.
    It inherits from the BaseMultiModalModel class.

    Attributes:
        model_name (str): Checkpoint to load; replaced by the Int4
            checkpoint when ``quantize`` is True.
        device (str): Passed to ``from_pretrained`` as ``device_map``.
        args (tuple): Extra positional arguments (stored, not used here).
        kwargs (dict): Extra keyword arguments (stored, not used here).
        quantize (bool): Load "Qwen/Qwen-VL-Chat-Int4" instead of
            ``model_name``.

    Examples:
    >>> model = QwenVLMultiModal()
    >>> model.run("Hello, how are you?", "https://example.com/image.jpg")

    """

    model_name: str = "Qwen/Qwen-VL-Chat"
    device: str = "cuda"
    args: tuple = field(default_factory=tuple)
    kwargs: dict = field(default_factory=dict)
    quantize: bool = False

    def __post_init__(self):
        """
        Initializes the QwenVLMultiModal object.
        It initializes the tokenizer and the model for the Qwen chatbot.
        """
        # NOTE(review): the dataclass-generated __init__ does not call the
        # base class initializer; confirm BaseMultiModalModel needs no setup.
        if self.quantize:
            self.model_name = "Qwen/Qwen-VL-Chat-Int4"

        # NOTE(security): trust_remote_code=True allows code shipped with the
        # checkpoint repository to run; only use trusted model names.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device,
            trust_remote_code=True,
        ).eval()

    def run(
        self, text: str, img: str, *args, **kwargs
    ) -> Optional[str]:
        """
        Runs the Qwen chatbot model on the given text and image inputs.

        Args:
            text (str): The input text for the chatbot.
            img (str): The input image for the chatbot.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Optional[str]: The decoded model output (special tokens are
            kept), or None if any step failed; the error is printed.
        """
        try:
            query = self.tokenizer.from_list_format(
                [
                    {"image": img, "text": text},
                ]
            )

            inputs = self.tokenizer(query, return_tensors="pt")
            inputs = inputs.to(self.model.device)
            pred = self.model.generate(**inputs)
            response = self.tokenizer.decode(
                pred.cpu()[0], skip_special_tokens=False
            )
            return response
        except Exception as error:
            # Best-effort API: report the failure and signal it with None.
            print(f"[ERROR]: [QwenVLMultiModal]: {error}")
            return None

    def chat(
        self, text: str, img: str, *args, **kwargs
    ) -> tuple[str, list]:
        """
        Chat with the model using text and image inputs.

        Args:
            text (str): The text input for the chat.
            img (str): The image input for the chat.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            tuple[str, list]: A tuple containing the response and chat history.

        Raises:
            Exception: If an error occurs during the chat.

        """
        try:
            # Send the caller's text with the image; previously a fixed
            # question was sent and ``text`` was silently ignored.
            response, history = self.model.chat(
                self.tokenizer,
                query=f"<img>{img}</img>{text}",
                history=None,
            )
            return response, history
        except Exception as e:
            raise Exception(
                "An error occurred during the chat."
            ) from e
|
@ -0,0 +1,94 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
VipLlavaForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
||||||
|
|
||||||
|
|
||||||
|
class VipLlavaMultiModal(BaseMultiModalModel):
    """
    A multi-modal model for VIP-LLAVA.

    Args:
        model_name (str): The name or path of the pre-trained model.
        max_new_tokens (int): The maximum number of new tokens to generate.
        device_map (str): The device mapping for the model.
        torch_dtype: The torch data type for the model.
        *args: Additional positional arguments, forwarded to the base class
            initializer and to both ``from_pretrained`` calls.
        **kwargs: Additional keyword arguments, forwarded to the base class
            initializer and to both ``from_pretrained`` calls.
    """

    # requests waits indefinitely when no timeout is given, so a stalled
    # image host would otherwise hang run() forever.
    image_fetch_timeout = 30

    def __init__(
        self,
        model_name: str = "llava-hf/vip-llava-7b-hf",
        max_new_tokens: int = 500,
        device_map: str = "auto",
        torch_dtype=torch.float16,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.device_map = device_map
        self.torch_dtype = torch_dtype

        self.model = VipLlavaForConditionalGeneration.from_pretrained(
            model_name,
            *args,
            device_map=device_map,
            torch_dtype=torch_dtype,
            **kwargs,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, *args, **kwargs
        )

    def run(self, text: str, img: str, *args, **kwargs):
        """
        Run the VIP-LLAVA model.

        Args:
            text (str): The input text.
            img (str): The URL of the input image.
            *args: Additional positional arguments, forwarded to the
                processor call.
            **kwargs: Additional keyword arguments, forwarded to both the
                processor call and ``model.generate``.

        Returns:
            str: The generated output text.
            tuple: A tuple containing None and the error message if an error occurs.
        """
        try:
            response = requests.get(
                img, stream=True, timeout=self.image_fetch_timeout
            )
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))

            # Follow the model's actual placement rather than assuming CUDA
            # device 0, so CPU and other device_map placements also work.
            inputs = self.processor(
                *args,
                text=text,
                images=image,
                return_tensors="pt",
                **kwargs,
            ).to(self.model.device, self.torch_dtype)

            # Generate
            generate_ids = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, **kwargs
            )

            # Decode only the tokens that follow the prompt.
            prompt_length = len(inputs["input_ids"][0])
            return self.processor.decode(
                generate_ids[0][prompt_length:],
                skip_special_tokens=True,
            )

        except requests.RequestException as error:
            # Covers connection errors, HTTP error statuses and timeouts.
            return None, f"Error fetching image: {error}"

        except Exception as error:
            return None, f"Error during model inference: {error}"
|
@ -1,19 +1,22 @@
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
def download_weights_from_url(
    url: str, save_path: str = "models/weights.pth"
):
    """
    Downloads model weights from the given URL and saves them to the specified path.

    The parent directory of ``save_path`` is created if it does not exist.

    Args:
        url (str): The URL from which to download the model weights.
        save_path (str, optional): The path where the downloaded weights should be saved.
            Defaults to "models/weights.pth".

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        OSError: If the destination cannot be created or written.
    """
    import os

    # The context manager releases the streamed connection even when
    # writing fails part-way through.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()

        # Create the target directory only after the server has answered
        # successfully, so a failed request leaves nothing behind.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

    print(f"Model weights downloaded and saved to {save_path}")
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from swarms.models.qwen import QwenVLMultiModal
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_init():
    """Constructing the model loads tokenizer and model with the expected arguments."""
    tokenizer_target = "swarms.models.qwen.AutoTokenizer.from_pretrained"
    model_target = (
        "swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
    )

    with patch(tokenizer_target) as mock_tokenizer:
        with patch(model_target) as mock_model:
            mock_tokenizer.return_value = Mock()
            mock_model.return_value = Mock()

            model = QwenVLMultiModal()

            # Both loaders must be hit exactly once with the model's name.
            mock_tokenizer.assert_called_once_with(
                model.model_name, trust_remote_code=True
            )
            mock_model.assert_called_once_with(
                model.model_name,
                device_map=model.device,
                trust_remote_code=True,
            )
|
||||||
|
|
||||||
|
|
||||||
|
def test_run():
    """run() returns whatever the tokenizer decodes from the generated ids."""
    # from_list_format, decode and generate do not exist on the Auto*
    # classes, so they cannot be patched there. Patch the loaders instead
    # and configure the objects they return.
    with patch(
        "swarms.models.qwen.AutoTokenizer.from_pretrained"
    ) as mock_tokenizer_loader, patch(
        "swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
    ) as mock_model_loader:
        tokenizer = Mock()
        tokenizer.from_list_format.return_value = "query"
        # tokenizer(...) -> encoded inputs; .to(device) -> mapping that
        # run() unpacks into generate(**inputs).
        tokenizer.return_value.to.return_value = {"input_ids": "ids"}
        tokenizer.decode.return_value = "response"
        mock_tokenizer_loader.return_value = tokenizer

        generated = Mock()
        generated.cpu.return_value = ["token_ids"]
        loaded_model = Mock()
        # __post_init__ keeps the result of .eval() as the model.
        loaded_model.eval.return_value = loaded_model
        loaded_model.generate.return_value = generated
        mock_model_loader.return_value = loaded_model

        model = QwenVLMultiModal()
        response = model.run(
            "Hello, how are you?", "https://example.com/image.jpg"
        )

        assert response == "response"
        loaded_model.generate.assert_called_once_with(input_ids="ids")
        tokenizer.decode.assert_called_once_with(
            "token_ids", skip_special_tokens=False
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat():
    """chat() returns the (response, history) pair produced by the model."""
    # ``chat`` is not an attribute of AutoModelForCausalLM, so it cannot be
    # patched there. Patch the loaders and configure the returned model.
    with patch(
        "swarms.models.qwen.AutoTokenizer.from_pretrained"
    ) as mock_tokenizer_loader, patch(
        "swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
    ) as mock_model_loader:
        mock_tokenizer_loader.return_value = Mock()

        loaded_model = Mock()
        # __post_init__ keeps the result of .eval() as the model.
        loaded_model.eval.return_value = loaded_model
        loaded_model.chat.return_value = ("response", ["history"])
        mock_model_loader.return_value = loaded_model

        model = QwenVLMultiModal()
        response, history = model.chat(
            "Hello, how are you?", "https://example.com/image.jpg"
        )

        assert response == "response"
        assert history == ["history"]
        loaded_model.chat.assert_called_once()
|
Loading…
Reference in new issue