parent 403bed61fe
commit aea843437e
@@ -1,82 +1,82 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
-
-class MultiModalLlava:
-    """
-    LLava Model
-
-    Args:
-        model_name_or_path: The model name or path to the model
-        revision: The revision of the model to use
-        device: The device to run the model on
-        max_new_tokens: The maximum number of tokens to generate
-        do_sample: Whether or not to use sampling
-        temperature: The temperature of the sampling
-        top_p: The top p value for sampling
-        top_k: The top k value for sampling
-        repetition_penalty: The repetition penalty for sampling
-        device_map: The device map to use
-
-    Methods:
-        __call__: Call the model
-        chat: Interactive chat in terminal
-
-    Example:
-        >>> from swarms.models.llava import LlavaModel
-        >>> model = LlavaModel(device="cpu")
-        >>> model("Hello, I am a robot.")
-    """
-
-    def __init__(
-        self,
-        model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ",
-        revision="main",
-        device="cuda",
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.95,
-        top_k=40,
-        repetition_penalty=1.1,
-        device_map: str = "auto",
-    ):
-        self.device = device
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path,
-            device_map=device_map,
-            trust_remote_code=False,
-            revision=revision,
-        ).to(self.device)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path, use_fast=True
-        )
-        self.pipe = pipeline(
-            "text-generation",
-            model=self.model,
-            tokenizer=self.tokenizer,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            device=0 if self.device == "cuda" else -1,
-        )
-
-    def __call__(self, prompt):
-        """Call the model"""
-        return self.pipe(prompt)[0]["generated_text"]
-
-    def chat(self):
-        """Interactive chat in terminal"""
-        print(
-            "Starting chat with LlavaModel. Type 'exit' to end the"
-            " session."
-        )
-        while True:
-            user_input = input("You: ")
-            if user_input.lower() == "exit":
-                break
-            response = self(user_input)
-            print(f"Model: {response}")
+import requests
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from typing import Tuple, Union
+from io import BytesIO
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+class LavaMultiModal(BaseMultiModalModel):
+    """
+    A class to handle multi-modal inputs (text and image) using the Llava model for conditional generation.
+
+    Attributes:
+        model_name (str): The name or path of the pre-trained model.
+        max_length (int): The maximum length of the generated sequence.
+
+    Args:
+        model_name (str): The name of the pre-trained model.
+        max_length (int): The maximum length of the generated sequence.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+
+    Examples:
+        >>> model = LavaMultiModal()
+        >>> model.run("A cat", "https://example.com/cat.jpg")
+
+    """
+
+    def __init__(
+        self,
+        model_name: str = "llava-hf/llava-1.5-7b-hf",
+        max_length: int = 30,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.model_name = model_name
+        self.max_length = max_length
+
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            model_name, *args, **kwargs
+        )
+        self.processor = AutoProcessor.from_pretrained(model_name)
+
+    def run(
+        self, text: str, img: str, *args, **kwargs
+    ) -> Union[str, Tuple[None, str]]:
+        """
+        Processes the input text and image, and generates a response.
+
+        Args:
+            text (str): The input text for the model.
+            img (str): The URL of the image to process.
+            max_length (int): The maximum length of the generated sequence.
+
+        Returns:
+            Union[str, Tuple[None, str]]: The generated response string or a tuple (None, error message) in case of an error.
+        """
+        try:
+            response = requests.get(img, stream=True)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content))
+
+            inputs = self.processor(
+                text=text, images=image, return_tensors="pt"
+            )
+
+            # Generate
+            generate_ids = self.model.generate(
+                **inputs, max_length=self.max_length, **kwargs
+            )
+            return self.processor.batch_decode(
+                generate_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False,
+                *args,
+            )[0]
+
+        except requests.RequestException as e:
+            return None, f"Error fetching image: {str(e)}"
+        except Exception as e:
+            return None, f"Error during model processing: {str(e)}"
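A minimal usage sketch for the new LavaMultiModal class, assuming the module path shown in the old docstring (swarms.models.llava) and access to the default llava-hf/llava-1.5-7b-hf checkpoint; the prompt and image URL are placeholders:

from swarms.models.llava import LavaMultiModal

# max_length caps the total generated sequence, matching the constructor default above
model = LavaMultiModal(model_name="llava-hf/llava-1.5-7b-hf", max_length=30)

# run() fetches the image over HTTP, builds processor inputs, and decodes the output;
# on failure it returns a (None, error message) tuple instead of a string
result = model.run("What is shown in this image?", "https://example.com/cat.jpg")
print(result)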
@@ -0,0 +1,108 @@
+from dataclasses import dataclass, field
+from typing import Optional, Tuple
+
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+@dataclass
+class QwenVLMultiModal(BaseMultiModalModel):
+    """
+    QwenVLMultiModal is a class that represents a multi-modal model for Qwen chatbot.
+    It inherits from the BaseMultiModalModel class.
+
+    Examples:
+        >>> model = QwenVLMultiModal()
+        >>> model.run("Hello, how are you?", "https://example.com/image.jpg")
+
+    """
+
+    model_name: str = "Qwen/Qwen-VL-Chat"
+    device: str = "cuda"
+    args: tuple = field(default_factory=tuple)
+    kwargs: dict = field(default_factory=dict)
+    quantize: bool = False
+
+    def __post_init__(self):
+        """
+        Initializes the QwenVLMultiModal object.
+        It initializes the tokenizer and the model for the Qwen chatbot.
+        """
+
+        if self.quantize:
+            self.model_name = "Qwen/Qwen-VL-Chat-Int4"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name, trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            trust_remote_code=True,
+        ).eval()
+
+    def run(
+        self, text: str, img: str, *args, **kwargs
+    ) -> Tuple[Optional[str], Optional[Image.Image]]:
+        """
+        Runs the Qwen chatbot model on the given text and image inputs.
+
+        Args:
+            text (str): The input text for the chatbot.
+            img (str): The input image for the chatbot.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Tuple[Optional[str], Optional[Image.Image]]: A tuple containing the response generated by the chatbot
+            and the image associated with the response (if any).
+        """
+        try:
+            query = self.tokenizer.from_list_format(
+                [
+                    {"image": img, "text": text},
+                ]
+            )
+
+            inputs = self.tokenizer(query, return_tensors="pt")
+            inputs = inputs.to(self.model.device)
+            pred = self.model.generate(**inputs)
+            response = self.tokenizer.decode(
+                pred.cpu()[0], skip_special_tokens=False
+            )
+            return response
+        except Exception as error:
+            print(f"[ERROR]: [QwenVLMultiModal]: {error}")
+
+    def chat(
+        self, text: str, img: str, *args, **kwargs
+    ) -> tuple[str, list]:
+        """
+        Chat with the model using text and image inputs.
+
+        Args:
+            text (str): The text input for the chat.
+            img (str): The image input for the chat.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            tuple[str, list]: A tuple containing the response and chat history.
+
+        Raises:
+            Exception: If an error occurs during the chat.
+
+        """
+        try:
+            response, history = self.model.chat(
+                self.tokenizer,
+                query=f"<img>{img}</img>这是什么",
+                history=None,
+            )
+            return response, history
+        except Exception as e:
+            raise Exception(
+                "An error occurred during the chat."
+            ) from e
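A similar sketch for QwenVLMultiModal, using the import path from the new tests (swarms.models.qwen) and assuming the Qwen/Qwen-VL-Chat weights are available; note that run() formats both inputs with the tokenizer's from_list_format, while chat() currently sends the fixed query "这是什么" ("What is this?") about the image rather than the text argument:

from swarms.models.qwen import QwenVLMultiModal

# quantize=True swaps in the Qwen/Qwen-VL-Chat-Int4 checkpoint, per __post_init__
model = QwenVLMultiModal(device="cuda", quantize=False)

# Single-turn generation over an image URL plus a text prompt (placeholder URL)
response = model.run("Describe this image.", "https://example.com/image.jpg")
print(response)

# Chat interface backed by the underlying model's chat() method
response, history = model.chat("Describe this image.", "https://example.com/image.jpg")
print(response)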
@@ -0,0 +1,94 @@
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+from transformers import (
+    AutoProcessor,
+    VipLlavaForConditionalGeneration,
+)
+
+from swarms.models.base_multimodal_model import BaseMultiModalModel
+
+
+class VipLlavaMultiModal(BaseMultiModalModel):
+    """
+    A multi-modal model for VIP-LLAVA.
+
+    Args:
+        model_name (str): The name or path of the pre-trained model.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+        device_map (str): The device mapping for the model.
+        torch_dtype: The torch data type for the model.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "llava-hf/vip-llava-7b-hf",
+        max_new_tokens: int = 500,
+        device_map: str = "auto",
+        torch_dtype=torch.float16,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.model_name = model_name
+        self.max_new_tokens = max_new_tokens
+        self.device_map = device_map
+        self.torch_dtype = torch_dtype
+
+        self.model = VipLlavaForConditionalGeneration.from_pretrained(
+            model_name,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            *args,
+            **kwargs,
+        )
+        self.processor = AutoProcessor.from_pretrained(
+            model_name, *args, **kwargs
+        )
+
+    def run(self, text: str, img: str, *args, **kwargs):
+        """
+        Run the VIP-LLAVA model.
+
+        Args:
+            text (str): The input text.
+            img (str): The URL of the input image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The generated output text.
+            tuple: A tuple containing None and the error message if an error occurs.
+        """
+        try:
+            response = requests.get(img, stream=True)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content))
+
+            inputs = self.processor(
+                text=text,
+                images=image,
+                return_tensors="pt",
+                *args,
+                **kwargs,
+            ).to(0, self.torch_dtype)
+
+            # Generate
+            generate_ids = self.model.generate(
+                **inputs, max_new_tokens=self.max_new_tokens, **kwargs
+            )
+
+            return self.processor.decode(
+                generate_ids[0][len(inputs["input_ids"][0]) :],
+                skip_special_tokens=True,
+            )
+
+        except requests.RequestException as error:
+            return None, f"Error fetching image: {error}"
+
+        except Exception as error:
+            return None, f"Error during model inference: {error}"
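A corresponding sketch for VipLlavaMultiModal; the module path is an assumption (the diff does not show where the new file lives) and the image URL is a placeholder. The model loads with device_map="auto" and float16 weights by default, and run() fetches the image over HTTP before generating:

# Assumed import path for the new module; adjust to the actual file location
from swarms.models.vip_llava import VipLlavaMultiModal

model = VipLlavaMultiModal(max_new_tokens=500)

# Returns the decoded completion, or a (None, error message) tuple on failure
result = model.run("Describe the scene.", "https://example.com/photo.jpg")
print(result)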
@@ -1,19 +1,22 @@
 import requests


-def download_weights_from_url(url: str, save_path: str = "models/weights.pth"):
+def download_weights_from_url(
+    url: str, save_path: str = "models/weights.pth"
+):
     """
     Downloads model weights from the given URL and saves them to the specified path.

     Args:
         url (str): The URL from which to download the model weights.
         save_path (str, optional): The path where the downloaded weights should be saved.
             Defaults to "models/weights.pth".
     """
     response = requests.get(url, stream=True)
     response.raise_for_status()

     with open(save_path, "wb") as f:
         for chunk in response.iter_content(chunk_size=8192):
             f.write(chunk)

     print(f"Model weights downloaded and saved to {save_path}")
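A quick sketch of the reformatted download helper; the import path and URL are placeholders, and the models/ directory is assumed to exist because the function opens save_path directly:

# Assumed import path for the helper shown above
from swarms.utils.download_weights_from_url import download_weights_from_url

# Streams the response in 8 KB chunks and writes it to save_path
download_weights_from_url(
    "https://example.com/weights.pth",
    save_path="models/weights.pth",
)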
@@ -0,0 +1,60 @@
+from unittest.mock import Mock, patch
+from swarms.models.qwen import QwenVLMultiModal
+
+
+def test_post_init():
+    with patch(
+        "swarms.models.qwen.AutoTokenizer.from_pretrained"
+    ) as mock_tokenizer, patch(
+        "swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
+    ) as mock_model:
+        mock_tokenizer.return_value = Mock()
+        mock_model.return_value = Mock()
+
+        model = QwenVLMultiModal()
+        mock_tokenizer.assert_called_once_with(
+            model.model_name, trust_remote_code=True
+        )
+        mock_model.assert_called_once_with(
+            model.model_name,
+            device_map=model.device,
+            trust_remote_code=True,
+        )
+
+
+def test_run():
+    with patch(
+        "swarms.models.qwen.AutoTokenizer.from_list_format"
+    ) as mock_format, patch(
+        "swarms.models.qwen.AutoTokenizer.__call__"
+    ) as mock_call, patch(
+        "swarms.models.qwen.AutoModelForCausalLM.generate"
+    ) as mock_generate, patch(
+        "swarms.models.qwen.AutoTokenizer.decode"
+    ) as mock_decode:
+        mock_format.return_value = Mock()
+        mock_call.return_value = Mock()
+        mock_generate.return_value = Mock()
+        mock_decode.return_value = "response"
+
+        model = QwenVLMultiModal()
+        response = model.run(
+            "Hello, how are you?", "https://example.com/image.jpg"
+        )
+
+        assert response == "response"
+
+
+def test_chat():
+    with patch(
+        "swarms.models.qwen.AutoModelForCausalLM.chat"
+    ) as mock_chat:
+        mock_chat.return_value = ("response", ["history"])
+
+        model = QwenVLMultiModal()
+        response, history = model.chat(
+            "Hello, how are you?", "https://example.com/image.jpg"
+        )
+
+        assert response == "response"
+        assert history == ["history"]