parent
f93bc98952
commit
2f88e92930
@ -1,80 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
||||||
from swarms.models.sam_supervision import SegmentAnythingMarkGenerator
|
|
||||||
from swarms.utils.supervision_masking import refine_marks
|
|
||||||
from swarms.utils.supervision_visualizer import MarkVisualizer
|
|
||||||
|
|
||||||
|
|
||||||
class GPT4VSAM(BaseMultiModalModel):
|
|
||||||
"""
|
|
||||||
GPT4VSAM class represents a multi-modal model that combines the capabilities of GPT-4 and SegmentAnythingMarkGenerator.
|
|
||||||
It takes an instance of BaseMultiModalModel (vlm) and a device as input and provides methods for loading images and making predictions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vlm (BaseMultiModalModel): An instance of BaseMultiModalModel representing the visual language model.
|
|
||||||
device (str, optional): The device to be used for computation. Defaults to "cuda".
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
vlm (BaseMultiModalModel): An instance of BaseMultiModalModel representing the visual language model.
|
|
||||||
device (str): The device to be used for computation.
|
|
||||||
sam (SegmentAnythingMarkGenerator): An instance of SegmentAnythingMarkGenerator for generating marks.
|
|
||||||
visualizer (MarkVisualizer): An instance of MarkVisualizer for visualizing marks.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
load_img(img: str) -> Any: Loads an image from the given file path.
|
|
||||||
__call__(task: str, img: str, *args, **kwargs) -> Any: Makes predictions using the visual language model.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vlm: BaseMultiModalModel,
|
|
||||||
device: str = "cuda",
|
|
||||||
return_related_marks: bool = False,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.vlm = vlm
|
|
||||||
self.device = device
|
|
||||||
self.return_related_marks = return_related_marks
|
|
||||||
|
|
||||||
self.sam = SegmentAnythingMarkGenerator(device, *args, **kwargs)
|
|
||||||
self.visualizer = MarkVisualizer(*args, **kwargs)
|
|
||||||
|
|
||||||
def load_img(self, img: str) -> Any:
|
|
||||||
"""
|
|
||||||
Loads an image from the given file path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img (str): The file path of the image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The loaded image.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return cv2.imread(img)
|
|
||||||
|
|
||||||
def __call__(self, task: str, img: str, *args, **kwargs) -> Any:
|
|
||||||
"""
|
|
||||||
Makes predictions using the visual language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): The task for which predictions are to be made.
|
|
||||||
img (str): The file path of the image.
|
|
||||||
*args: Additional positional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The predictions made by the visual language model.
|
|
||||||
|
|
||||||
"""
|
|
||||||
img = self.load_img(img)
|
|
||||||
|
|
||||||
marks = self.sam(image=img)
|
|
||||||
marks = refine_marks(marks=marks)
|
|
||||||
|
|
||||||
return self.vlm(task, img, *args, **kwargs)
|
|
@ -1,244 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from numpy.linalg import norm
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from transformers import (
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
BitsAndBytesConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
from swarms.models.base_embedding_model import BaseEmbeddingModel
|
|
||||||
|
|
||||||
|
|
||||||
def cos_sim(a, b):
|
|
||||||
return a @ b.T / (norm(a) * norm(b))
|
|
||||||
|
|
||||||
|
|
||||||
class JinaEmbeddings(BaseEmbeddingModel):
|
|
||||||
"""
|
|
||||||
Jina Embeddings model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id (str): The model id to use. Default is "jinaai/jina-embeddings-v2-base-en".
|
|
||||||
device (str): The device to run the model on. Default is "cuda".
|
|
||||||
huggingface_api_key (str): The Hugging Face API key. Default is None.
|
|
||||||
max_length (int): The maximum length of the response. Default is 500.
|
|
||||||
quantize (bool): Whether to quantize the model. Default is False.
|
|
||||||
quantization_config (dict): The quantization configuration. Default is None.
|
|
||||||
verbose (bool): Whether to print verbose logs. Default is False.
|
|
||||||
distributed (bool): Whether to use distributed processing. Default is False.
|
|
||||||
decoding (bool): Whether to use decoding. Default is False.
|
|
||||||
cos_sim (callable): The cosine similarity function. Default is cos_sim.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
run: _description_
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> model = JinaEmbeddings(
|
|
||||||
>>> max_length=8192,
|
|
||||||
>>> device="cuda",
|
|
||||||
>>> quantize=True,
|
|
||||||
>>> huggingface_api_key="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"
|
|
||||||
>>> )
|
|
||||||
>>> embeddings = model("Encode this super long document text")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str = "jinaai/jina-embeddings-v2-base-en",
|
|
||||||
device: str = None,
|
|
||||||
huggingface_api_key: str = None,
|
|
||||||
max_length: int = 500,
|
|
||||||
quantize: bool = False,
|
|
||||||
quantization_config: dict = None,
|
|
||||||
verbose=False,
|
|
||||||
distributed=False,
|
|
||||||
decoding=False,
|
|
||||||
cos_sim: callable = cos_sim,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.device = (
|
|
||||||
device
|
|
||||||
if device
|
|
||||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
)
|
|
||||||
self.huggingface_api_key = huggingface_api_key
|
|
||||||
self.model_id = model_id
|
|
||||||
self.max_length = max_length
|
|
||||||
self.verbose = verbose
|
|
||||||
self.distributed = distributed
|
|
||||||
self.decoding = decoding
|
|
||||||
self.model, self.tokenizer = None, None
|
|
||||||
self.cos_sim = cos_sim
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
assert (
|
|
||||||
torch.cuda.device_count() > 1
|
|
||||||
), "You need more than 1 gpu for distributed processing"
|
|
||||||
|
|
||||||
# If API key then set it
|
|
||||||
if self.huggingface_api_key:
|
|
||||||
os.environ["HF_TOKEN"] = self.huggingface_api_key
|
|
||||||
|
|
||||||
bnb_config = None
|
|
||||||
if quantize:
|
|
||||||
if not quantization_config:
|
|
||||||
quantization_config = {
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"bnb_4bit_use_double_quant": True,
|
|
||||||
"bnb_4bit_quant_type": "nf4",
|
|
||||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
|
||||||
}
|
|
||||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id,
|
|
||||||
quantization_config=bnb_config,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model # .to(self.device)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(
|
|
||||||
f"Failed to load the model or the tokenizer: {e}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
"""Load the model"""
|
|
||||||
if not self.model or not self.tokenizer:
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_id
|
|
||||||
)
|
|
||||||
|
|
||||||
bnb_config = (
|
|
||||||
BitsAndBytesConfig(**self.quantization_config)
|
|
||||||
if self.quantization_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id,
|
|
||||||
quantization_config=bnb_config,
|
|
||||||
trust_remote_code=True,
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
self.model = DDP(self.model)
|
|
||||||
except Exception as error:
|
|
||||||
self.logger.error(
|
|
||||||
"Failed to load the model or the tokenizer:"
|
|
||||||
f" {error}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run(self, task: str, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- task (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_length = self.max_length
|
|
||||||
|
|
||||||
try:
|
|
||||||
embeddings = self.model.encode(
|
|
||||||
[task], max_length=max_length, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cos_sim:
|
|
||||||
print(cos_sim(embeddings[0], embeddings[1]))
|
|
||||||
else:
|
|
||||||
return embeddings[0]
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run_async(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Run the model asynchronously
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): Task to run.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
|
|
||||||
'Once upon a time in a land far, far away...'
|
|
||||||
>>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7)
|
|
||||||
['In the deep jungles,',
|
|
||||||
'At the heart of the city,']
|
|
||||||
>>> mpt_instance.freeze_model()
|
|
||||||
>>> mpt_instance.unfreeze_model()
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Wrapping synchronous calls with async
|
|
||||||
return self.run(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def __call__(self, task: str, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- task (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
self.load_model()
|
|
||||||
|
|
||||||
max_length = self.max_length
|
|
||||||
|
|
||||||
try:
|
|
||||||
embeddings = self.model.encode(
|
|
||||||
[task], max_length=max_length, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cos_sim:
|
|
||||||
print(cos_sim(embeddings[0], embeddings[1]))
|
|
||||||
else:
|
|
||||||
return embeddings[0]
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def __call_async__(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""Call the model asynchronously""" ""
|
|
||||||
return await self.run_async(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def save_model(self, path: str):
|
|
||||||
"""Save the model to a given path"""
|
|
||||||
self.model.save_pretrained(path)
|
|
||||||
self.tokenizer.save_pretrained(path)
|
|
||||||
|
|
||||||
def gpu_available(self) -> bool:
|
|
||||||
"""Check if GPU is available"""
|
|
||||||
return torch.cuda.is_available()
|
|
||||||
|
|
||||||
def memory_consumption(self) -> dict:
|
|
||||||
"""Get the memory consumption of the GPU"""
|
|
||||||
if self.gpu_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
allocated = torch.cuda.memory_allocated()
|
|
||||||
reserved = torch.cuda.memory_reserved()
|
|
||||||
return {"allocated": allocated, "reserved": reserved}
|
|
||||||
else:
|
|
||||||
return {"error": "GPU not available"}
|
|
||||||
|
|
||||||
def try_embed_chunk(self, chunk: str) -> list[float]:
|
|
||||||
return super().try_embed_chunk(chunk)
|
|
@ -1,142 +0,0 @@
|
|||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from skimage import transform
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
def sam_model_registry():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MedicalSAM:
|
|
||||||
"""
|
|
||||||
MedicalSAM class for performing semantic segmentation on medical images using the SAM model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model_path (str): The file path to the model weights.
|
|
||||||
device (str): The device to run the model on (default is "cuda:0").
|
|
||||||
model_weights_url (str): The URL to download the model weights from.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
__post_init__(): Initializes the MedicalSAM object.
|
|
||||||
download_model_weights(model_path: str): Downloads the model weights from the specified URL and saves them to the given file path.
|
|
||||||
preprocess(img): Preprocesses the input image.
|
|
||||||
run(img, box): Runs the semantic segmentation on the input image within the specified bounding box.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_path: str
|
|
||||||
device: str = "cuda:0"
|
|
||||||
model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if not os.path.exists(self.model_path):
|
|
||||||
self.download_model_weights(self.model_path)
|
|
||||||
|
|
||||||
self.model = sam_model_registry["vit_b"](
|
|
||||||
checkpoint=self.model_path
|
|
||||||
)
|
|
||||||
self.model = self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
def download_model_weights(self, model_path: str):
|
|
||||||
"""
|
|
||||||
Downloads the model weights from the specified URL and saves them to the given file path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path (str): The file path where the model weights will be saved.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If the model weights fail to download.
|
|
||||||
"""
|
|
||||||
response = requests.get(self.model_weights_url, stream=True)
|
|
||||||
if response.status_code == 200:
|
|
||||||
with open(model_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
else:
|
|
||||||
raise Exception("Failed to download model weights.")
|
|
||||||
|
|
||||||
def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
|
|
||||||
"""
|
|
||||||
Preprocesses the input image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img: The input image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
img_tensor: The preprocessed image tensor.
|
|
||||||
H: The original height of the image.
|
|
||||||
W: The original width of the image.
|
|
||||||
"""
|
|
||||||
if len(img.shape) == 2:
|
|
||||||
img = np.repeat(img[:, :, None], 3, axis=-1)
|
|
||||||
H, W, _ = img.shape
|
|
||||||
img = transform.resize(
|
|
||||||
img,
|
|
||||||
(1024, 1024),
|
|
||||||
order=3,
|
|
||||||
preserve_range=True,
|
|
||||||
anti_aliasing=True,
|
|
||||||
).astype(np.uint8)
|
|
||||||
img = img - img.min() / np.clip(
|
|
||||||
img.max() - img.min(), a_min=1e-8, a_max=None
|
|
||||||
)
|
|
||||||
img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
|
|
||||||
return img, H, W
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Runs the semantic segmentation on the input image within the specified bounding box.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img: The input image.
|
|
||||||
box: The bounding box coordinates (x1, y1, x2, y2).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
medsam_seg: The segmented image.
|
|
||||||
"""
|
|
||||||
img_tensor, H, W = self.preprocess(img)
|
|
||||||
img_tensor = img_tensor.to(self.device)
|
|
||||||
box_1024 = box / np.array([W, H, W, H]) * 1024
|
|
||||||
img = self.model.image_encoder(img_tensor)
|
|
||||||
|
|
||||||
box_torch = torch.as_tensor(
|
|
||||||
box_1024, dtype=torch.float, device=img_tensor.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(box_torch.shape) == 2:
|
|
||||||
box_torch = box_torch[:, None, :]
|
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
|
||||||
points=None,
|
|
||||||
boxes=box_torch,
|
|
||||||
masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
low_res_logits, _ = self.model.mask_decoder(
|
|
||||||
image_embeddings=img,
|
|
||||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
|
||||||
sparse_prompt_embeddings=sparse_embeddings,
|
|
||||||
dense_prompt_embeddings=dense_embeddings,
|
|
||||||
multimask_output=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
low_res_pred = torch.sigmoid(low_res_logits)
|
|
||||||
low_res_pred = F.interpolate(
|
|
||||||
low_res_pred,
|
|
||||||
size=(H, W),
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
low_res_pred = low_res_pred.squeeze().cpu().numpy()
|
|
||||||
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
|
||||||
|
|
||||||
return medsam_seg
|
|
@ -1,10 +0,0 @@
|
|||||||
from swarms.models.popular_llms import OpenAIChat
|
|
||||||
|
|
||||||
|
|
||||||
class MistralAPILLM(OpenAIChat):
|
|
||||||
def __init__(self, url):
|
|
||||||
super().__init__()
|
|
||||||
self.openai_proxy_url = url
|
|
||||||
|
|
||||||
def __call__(self, task: str):
|
|
||||||
super().__call__(task)
|
|
@ -1,182 +1,5 @@
|
|||||||
from __future__ import annotations
|
from langchain_community.llms.google_palm import GooglePalm
|
||||||
|
|
||||||
import logging
|
__all__ = [
|
||||||
from typing import Any, Callable
|
"GooglePalm",
|
||||||
|
]
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
||||||
from langchain.llms import BaseLLM
|
|
||||||
from langchain.pydantic_v1 import BaseModel
|
|
||||||
from langchain.schema import Generation, LLMResult
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
|
||||||
from tenacity import (
|
|
||||||
before_sleep_log,
|
|
||||||
retry,
|
|
||||||
retry_if_exception_type,
|
|
||||||
stop_after_attempt,
|
|
||||||
wait_exponential,
|
|
||||||
)
|
|
||||||
from pydantic import model_validator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
|
||||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
|
||||||
try:
|
|
||||||
import google.api_core.exceptions
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import google-api-core python package. "
|
|
||||||
"Please install it with `pip install google-api-core`."
|
|
||||||
)
|
|
||||||
|
|
||||||
multiplier = 2
|
|
||||||
min_seconds = 1
|
|
||||||
max_seconds = 60
|
|
||||||
max_retries = 10
|
|
||||||
|
|
||||||
return retry(
|
|
||||||
reraise=True,
|
|
||||||
stop=stop_after_attempt(max_retries),
|
|
||||||
wait=wait_exponential(
|
|
||||||
multiplier=multiplier, min=min_seconds, max=max_seconds
|
|
||||||
),
|
|
||||||
retry=(
|
|
||||||
retry_if_exception_type(
|
|
||||||
google.api_core.exceptions.ResourceExhausted
|
|
||||||
)
|
|
||||||
| retry_if_exception_type(
|
|
||||||
google.api_core.exceptions.ServiceUnavailable
|
|
||||||
)
|
|
||||||
| retry_if_exception_type(
|
|
||||||
google.api_core.exceptions.GoogleAPIError
|
|
||||||
)
|
|
||||||
),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
|
|
||||||
"""Use tenacity to retry the completion call."""
|
|
||||||
retry_decorator = _create_retry_decorator()
|
|
||||||
|
|
||||||
@retry_decorator
|
|
||||||
def _generate_with_retry(**kwargs: Any) -> Any:
|
|
||||||
return llm.client.generate_text(**kwargs)
|
|
||||||
|
|
||||||
return _generate_with_retry(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
|
||||||
"""Strip erroneous leading spaces from text.
|
|
||||||
|
|
||||||
The PaLM API will sometimes erroneously return a single leading space in all
|
|
||||||
lines > 1. This function strips that space.
|
|
||||||
"""
|
|
||||||
has_leading_space = all(
|
|
||||||
not line or line[0] == " " for line in text.split("\n")[1:]
|
|
||||||
)
|
|
||||||
if has_leading_space:
|
|
||||||
return text.replace("\n ", "\n")
|
|
||||||
else:
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class GooglePalm(BaseLLM, BaseModel):
|
|
||||||
"""Google PaLM models."""
|
|
||||||
|
|
||||||
client: Any #: :meta private:
|
|
||||||
google_api_key: str | None
|
|
||||||
model_name: str = "models/text-bison-001"
|
|
||||||
"""Model name to use."""
|
|
||||||
temperature: float = 0.7
|
|
||||||
"""Run inference with this temperature. Must by in the closed interval
|
|
||||||
[0.0, 1.0]."""
|
|
||||||
top_p: float | None = None
|
|
||||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
|
||||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
|
||||||
top_k: int | None = None
|
|
||||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
|
||||||
Must be positive."""
|
|
||||||
max_output_tokens: int | None = None
|
|
||||||
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
|
||||||
If unset, will default to 64."""
|
|
||||||
n: int = 1
|
|
||||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
|
||||||
not return the full n completions if duplicates are generated."""
|
|
||||||
|
|
||||||
@model_validator()
|
|
||||||
@classmethod
|
|
||||||
def validate_environment(cls, values: dict) -> dict:
|
|
||||||
"""Validate api key, python package exists."""
|
|
||||||
google_api_key = get_from_dict_or_env(
|
|
||||||
values, "google_api_key", "GOOGLE_API_KEY"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
import google.generativeai as genai
|
|
||||||
|
|
||||||
genai.configure(api_key=google_api_key)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import google-generativeai python package."
|
|
||||||
" Please install it with `pip install"
|
|
||||||
" google-generativeai`."
|
|
||||||
)
|
|
||||||
|
|
||||||
values["client"] = genai
|
|
||||||
|
|
||||||
if (
|
|
||||||
values["temperature"] is not None
|
|
||||||
and not 0 <= values["temperature"] <= 1
|
|
||||||
):
|
|
||||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
|
||||||
|
|
||||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
|
||||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
|
||||||
|
|
||||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
|
||||||
raise ValueError("top_k must be positive")
|
|
||||||
|
|
||||||
if (
|
|
||||||
values["max_output_tokens"] is not None
|
|
||||||
and values["max_output_tokens"] <= 0
|
|
||||||
):
|
|
||||||
raise ValueError("max_output_tokens must be greater than zero")
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
prompts: list[str],
|
|
||||||
stop: list[str] | None = None,
|
|
||||||
run_manager: CallbackManagerForLLMRun | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> LLMResult:
|
|
||||||
generations = []
|
|
||||||
for prompt in prompts:
|
|
||||||
completion = generate_with_retry(
|
|
||||||
self,
|
|
||||||
model=self.model_name,
|
|
||||||
prompt=prompt,
|
|
||||||
stop_sequences=stop,
|
|
||||||
temperature=self.temperature,
|
|
||||||
top_p=self.top_p,
|
|
||||||
top_k=self.top_k,
|
|
||||||
max_output_tokens=self.max_output_tokens,
|
|
||||||
candidate_count=self.n,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_generations = []
|
|
||||||
for candidate in completion.candidates:
|
|
||||||
raw_text = candidate["output"]
|
|
||||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
|
||||||
prompt_generations.append(Generation(text=stripped_text))
|
|
||||||
generations.append(prompt_generations)
|
|
||||||
|
|
||||||
return LLMResult(generations=generations)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
"""Return type of llm."""
|
|
||||||
return "google_palm"
|
|
@ -1,16 +0,0 @@
|
|||||||
"""
|
|
||||||
|
|
||||||
TROCR for Multi-Modal OCR tasks
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class TrOCR:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
pass
|
|
@ -1,52 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from ultralytics import YOLO
|
|
||||||
|
|
||||||
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
||||||
|
|
||||||
|
|
||||||
class UltralyticsModel(BaseMultiModalModel):
|
|
||||||
"""
|
|
||||||
Initializes an instance of the Ultralytics model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): The name of the model.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "yolov8n.pt", *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.model = YOLO(model_name, *args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to initialize Ultralytics model: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, task: str, tasks: List[str] = None, *args, **kwargs
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Calls the Ultralytics model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): The task to perform.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result of the model call.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if tasks:
|
|
||||||
return self.model([tasks], *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.model(task, *args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to perform task '{task}' with Ultralytics"
|
|
||||||
f" model: {str(e)}"
|
|
||||||
)
|
|
@ -1,237 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from transformers import (
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
BitsAndBytesConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WizardLLMStoryTeller:
|
|
||||||
"""
|
|
||||||
A class for running inference on a given model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model_id (str): The ID of the model.
|
|
||||||
device (str): The device to run the model on (either 'cuda' or 'cpu').
|
|
||||||
max_length (int): The maximum length of the output sequence.
|
|
||||||
quantize (bool, optional): Whether to use quantization. Defaults to False.
|
|
||||||
quantization_config (dict, optional): The configuration for quantization.
|
|
||||||
verbose (bool, optional): Whether to print verbose logs. Defaults to False.
|
|
||||||
logger (logging.Logger, optional): The logger to use. Defaults to a basic logger.
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
```
|
|
||||||
from finetuning_suite import Inference
|
|
||||||
|
|
||||||
model_id = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF"
|
|
||||||
inference = Inference(model_id=model_id)
|
|
||||||
|
|
||||||
prompt_text = "Once upon a time"
|
|
||||||
generated_text = inference(prompt_text)
|
|
||||||
print(generated_text)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF",
|
|
||||||
device: str = None,
|
|
||||||
max_length: int = 500,
|
|
||||||
quantize: bool = False,
|
|
||||||
quantization_config: dict = None,
|
|
||||||
verbose=False,
|
|
||||||
# logger=None,
|
|
||||||
distributed=False,
|
|
||||||
decoding=False,
|
|
||||||
):
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.device = (
|
|
||||||
device
|
|
||||||
if device
|
|
||||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
)
|
|
||||||
self.model_id = model_id
|
|
||||||
self.max_length = max_length
|
|
||||||
self.verbose = verbose
|
|
||||||
self.distributed = distributed
|
|
||||||
self.decoding = decoding
|
|
||||||
self.model, self.tokenizer = None, None
|
|
||||||
# self.log = Logging()
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
assert (
|
|
||||||
torch.cuda.device_count() > 1
|
|
||||||
), "You need more than 1 gpu for distributed processing"
|
|
||||||
|
|
||||||
bnb_config = None
|
|
||||||
if quantize:
|
|
||||||
if not quantization_config:
|
|
||||||
quantization_config = {
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"bnb_4bit_use_double_quant": True,
|
|
||||||
"bnb_4bit_quant_type": "nf4",
|
|
||||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
|
||||||
}
|
|
||||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id, quantization_config=bnb_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model # .to(self.device)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(
|
|
||||||
f"Failed to load the model or the tokenizer: {e}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def load_model(self):
|
|
||||||
"""Load the model"""
|
|
||||||
if not self.model or not self.tokenizer:
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_id
|
|
||||||
)
|
|
||||||
|
|
||||||
bnb_config = (
|
|
||||||
BitsAndBytesConfig(**self.quantization_config)
|
|
||||||
if self.quantization_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id, quantization_config=bnb_config
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
self.model = DDP(self.model)
|
|
||||||
except Exception as error:
|
|
||||||
self.logger.error(
|
|
||||||
"Failed to load the model or the tokenizer:"
|
|
||||||
f" {error}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run(self, prompt_text: str):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- prompt_text (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
self.load_model()
|
|
||||||
|
|
||||||
max_length = self.max_length
|
|
||||||
|
|
||||||
try:
|
|
||||||
inputs = self.tokenizer.encode(
|
|
||||||
prompt_text, return_tensors="pt"
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# self.log.start()
|
|
||||||
|
|
||||||
if self.decoding:
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_length):
|
|
||||||
output_sequence = []
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs,
|
|
||||||
max_length=len(inputs) + 1,
|
|
||||||
do_sample=True,
|
|
||||||
)
|
|
||||||
output_tokens = outputs[0][-1]
|
|
||||||
output_sequence.append(output_tokens.item())
|
|
||||||
|
|
||||||
# print token in real-time
|
|
||||||
print(
|
|
||||||
self.tokenizer.decode(
|
|
||||||
[output_tokens],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
),
|
|
||||||
end="",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
inputs = outputs
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs, max_length=max_length, do_sample=True
|
|
||||||
)
|
|
||||||
|
|
||||||
del inputs
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
outputs[0], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def __call__(self, prompt_text: str):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- prompt_text (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
self.load_model()
|
|
||||||
|
|
||||||
max_length = self.max_
|
|
||||||
|
|
||||||
try:
|
|
||||||
inputs = self.tokenizer.encode(
|
|
||||||
prompt_text, return_tensors="pt"
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# self.log.start()
|
|
||||||
|
|
||||||
if self.decoding:
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_length):
|
|
||||||
output_sequence = []
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs,
|
|
||||||
max_length=len(inputs) + 1,
|
|
||||||
do_sample=True,
|
|
||||||
)
|
|
||||||
output_tokens = outputs[0][-1]
|
|
||||||
output_sequence.append(output_tokens.item())
|
|
||||||
|
|
||||||
# print token in real-time
|
|
||||||
print(
|
|
||||||
self.tokenizer.decode(
|
|
||||||
[output_tokens],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
),
|
|
||||||
end="",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
inputs = outputs
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs, max_length=max_length, do_sample=True
|
|
||||||
)
|
|
||||||
|
|
||||||
del inputs
|
|
||||||
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
outputs[0], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
@ -1,288 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from transformers import (
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
BitsAndBytesConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class YarnMistral128:
|
|
||||||
"""
|
|
||||||
A class for running inference on a given model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model_id (str): The ID of the model.
|
|
||||||
device (str): The device to run the model on (either 'cuda' or 'cpu').
|
|
||||||
max_length (int): The maximum length of the output sequence.
|
|
||||||
quantize (bool, optional): Whether to use quantization. Defaults to False.
|
|
||||||
quantization_config (dict, optional): The configuration for quantization.
|
|
||||||
verbose (bool, optional): Whether to print verbose logs. Defaults to False.
|
|
||||||
logger (logging.Logger, optional): The logger to use. Defaults to a basic logger.
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
```
|
|
||||||
from finetuning_suite import Inference
|
|
||||||
|
|
||||||
model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
|
||||||
inference = Inference(model_id=model_id)
|
|
||||||
|
|
||||||
prompt_text = "Once upon a time"
|
|
||||||
generated_text = inference(prompt_text)
|
|
||||||
print(generated_text)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str = "NousResearch/Yarn-Mistral-7b-128k",
|
|
||||||
device: str = None,
|
|
||||||
max_length: int = 500,
|
|
||||||
quantize: bool = False,
|
|
||||||
quantization_config: dict = None,
|
|
||||||
verbose=False,
|
|
||||||
# logger=None,
|
|
||||||
distributed=False,
|
|
||||||
decoding=False,
|
|
||||||
):
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.device = (
|
|
||||||
device
|
|
||||||
if device
|
|
||||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
)
|
|
||||||
self.model_id = model_id
|
|
||||||
self.max_length = max_length
|
|
||||||
self.verbose = verbose
|
|
||||||
self.distributed = distributed
|
|
||||||
self.decoding = decoding
|
|
||||||
self.model, self.tokenizer = None, None
|
|
||||||
# self.log = Logging()
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
assert (
|
|
||||||
torch.cuda.device_count() > 1
|
|
||||||
), "You need more than 1 gpu for distributed processing"
|
|
||||||
|
|
||||||
bnb_config = None
|
|
||||||
if quantize:
|
|
||||||
if not quantization_config:
|
|
||||||
quantization_config = {
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"bnb_4bit_use_double_quant": True,
|
|
||||||
"bnb_4bit_quant_type": "nf4",
|
|
||||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
|
||||||
}
|
|
||||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id,
|
|
||||||
quantization_config=bnb_config,
|
|
||||||
use_flash_attention_2=True,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model # .to(self.device)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(
|
|
||||||
f"Failed to load the model or the tokenizer: {e}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def load_model(self):
|
|
||||||
"""Load the model"""
|
|
||||||
if not self.model or not self.tokenizer:
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_id
|
|
||||||
)
|
|
||||||
|
|
||||||
bnb_config = (
|
|
||||||
BitsAndBytesConfig(**self.quantization_config)
|
|
||||||
if self.quantization_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_id, quantization_config=bnb_config
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if self.distributed:
|
|
||||||
self.model = DDP(self.model)
|
|
||||||
except Exception as error:
|
|
||||||
self.logger.error(
|
|
||||||
"Failed to load the model or the tokenizer:"
|
|
||||||
f" {error}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run(self, prompt_text: str):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- prompt_text (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
self.load_model()
|
|
||||||
|
|
||||||
max_length = self.max_length
|
|
||||||
|
|
||||||
try:
|
|
||||||
inputs = self.tokenizer.encode(
|
|
||||||
prompt_text, return_tensors="pt"
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# self.log.start()
|
|
||||||
|
|
||||||
if self.decoding:
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_length):
|
|
||||||
output_sequence = []
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs,
|
|
||||||
max_length=len(inputs) + 1,
|
|
||||||
do_sample=True,
|
|
||||||
)
|
|
||||||
output_tokens = outputs[0][-1]
|
|
||||||
output_sequence.append(output_tokens.item())
|
|
||||||
|
|
||||||
# print token in real-time
|
|
||||||
print(
|
|
||||||
self.tokenizer.decode(
|
|
||||||
[output_tokens],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
),
|
|
||||||
end="",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
inputs = outputs
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs, max_length=max_length, do_sample=True
|
|
||||||
)
|
|
||||||
|
|
||||||
del inputs
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
outputs[0], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run_async(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Run the model asynchronously
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): Task to run.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
|
|
||||||
'Once upon a time in a land far, far away...'
|
|
||||||
>>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7)
|
|
||||||
['In the deep jungles,',
|
|
||||||
'At the heart of the city,']
|
|
||||||
>>> mpt_instance.freeze_model()
|
|
||||||
>>> mpt_instance.unfreeze_model()
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Wrapping synchronous calls with async
|
|
||||||
return self.run(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def __call__(self, prompt_text: str):
|
|
||||||
"""
|
|
||||||
Generate a response based on the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- prompt_text (str): Text to prompt the model.
|
|
||||||
- max_length (int): Maximum length of the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Generated text (str).
|
|
||||||
"""
|
|
||||||
self.load_model()
|
|
||||||
|
|
||||||
max_length = self.max_
|
|
||||||
|
|
||||||
try:
|
|
||||||
inputs = self.tokenizer.encode(
|
|
||||||
prompt_text, return_tensors="pt"
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# self.log.start()
|
|
||||||
|
|
||||||
if self.decoding:
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_length):
|
|
||||||
output_sequence = []
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs,
|
|
||||||
max_length=len(inputs) + 1,
|
|
||||||
do_sample=True,
|
|
||||||
)
|
|
||||||
output_tokens = outputs[0][-1]
|
|
||||||
output_sequence.append(output_tokens.item())
|
|
||||||
|
|
||||||
# print token in real-time
|
|
||||||
print(
|
|
||||||
self.tokenizer.decode(
|
|
||||||
[output_tokens],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
),
|
|
||||||
end="",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
inputs = outputs
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs, max_length=max_length, do_sample=True
|
|
||||||
)
|
|
||||||
|
|
||||||
del inputs
|
|
||||||
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
outputs[0], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to generate the text: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def __call_async__(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""Call the model asynchronously""" ""
|
|
||||||
return await self.run_async(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def save_model(self, path: str):
|
|
||||||
"""Save the model to a given path"""
|
|
||||||
self.model.save_pretrained(path)
|
|
||||||
self.tokenizer.save_pretrained(path)
|
|
||||||
|
|
||||||
def gpu_available(self) -> bool:
|
|
||||||
"""Check if GPU is available"""
|
|
||||||
return torch.cuda.is_available()
|
|
||||||
|
|
||||||
def memory_consumption(self) -> dict:
|
|
||||||
"""Get the memory consumption of the GPU"""
|
|
||||||
if self.gpu_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
allocated = torch.cuda.memory_allocated()
|
|
||||||
reserved = torch.cuda.memory_reserved()
|
|
||||||
return {"allocated": allocated, "reserved": reserved}
|
|
||||||
else:
|
|
||||||
return {"error": "GPU not available"}
|
|
@ -1,97 +0,0 @@
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class Yi34B200k:
|
|
||||||
"""
|
|
||||||
A class for eaasy interaction with Yi34B200k
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
-----------
|
|
||||||
model_id: str
|
|
||||||
The model id of the model to be used.
|
|
||||||
device_map: str
|
|
||||||
The device to be used for inference.
|
|
||||||
torch_dtype: str
|
|
||||||
The torch dtype to be used for inference.
|
|
||||||
max_length: int
|
|
||||||
The maximum length of the generated text.
|
|
||||||
repitition_penalty: float
|
|
||||||
The repitition penalty to be used for inference.
|
|
||||||
no_repeat_ngram_size: int
|
|
||||||
The no repeat ngram size to be used for inference.
|
|
||||||
temperature: float
|
|
||||||
The temperature to be used for inference.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
--------
|
|
||||||
__call__(self, task: str) -> str:
|
|
||||||
Generates text based on the given prompt.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str = "01-ai/Yi-34B-200K",
|
|
||||||
device_map: str = "auto",
|
|
||||||
torch_dtype: str = "auto",
|
|
||||||
max_length: int = 512,
|
|
||||||
repitition_penalty: float = 1.3,
|
|
||||||
no_repeat_ngram_size: int = 5,
|
|
||||||
temperature: float = 0.7,
|
|
||||||
top_k: int = 40,
|
|
||||||
top_p: float = 0.8,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.model_id = model_id
|
|
||||||
self.device_map = device_map
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.max_length = max_length
|
|
||||||
self.repitition_penalty = repitition_penalty
|
|
||||||
self.no_repeat_ngram_size = no_repeat_ngram_size
|
|
||||||
self.temperature = temperature
|
|
||||||
self.top_k = top_k
|
|
||||||
self.top_p = top_p
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
device_map=device_map,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, task: str):
|
|
||||||
"""
|
|
||||||
Generates text based on the given prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The input text prompt.
|
|
||||||
max_length (int): The maximum length of the generated text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The generated text.
|
|
||||||
"""
|
|
||||||
inputs = self.tokenizer(task, return_tensors="pt")
|
|
||||||
outputs = self.model.generate(
|
|
||||||
inputs.input_ids.cuda(),
|
|
||||||
max_length=self.max_length,
|
|
||||||
eos_token_id=self.tokenizer.eos_token_id,
|
|
||||||
do_sample=True,
|
|
||||||
repetition_penalty=self.repitition_penalty,
|
|
||||||
no_repeat_ngram_size=self.no_repeat_ngram_size,
|
|
||||||
temperature=self.temperature,
|
|
||||||
top_k=self.top_k,
|
|
||||||
top_p=self.top_p,
|
|
||||||
)
|
|
||||||
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
||||||
|
|
||||||
|
|
||||||
# # Example usage
|
|
||||||
# yi34b = Yi34B200k()
|
|
||||||
# prompt = "There's a place where time stands still. A place of breathtaking wonder, but also"
|
|
||||||
# generated_text = yi34b(prompt)
|
|
||||||
# print(generated_text)
|
|
@ -1,223 +0,0 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
# Import necessary modules
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from transformers import BioGptForCausalLM, BioGptTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
# Fixture for BioGPT instance
|
|
||||||
@pytest.fixture
|
|
||||||
def biogpt_instance():
|
|
||||||
from swarms.models import BioGPT
|
|
||||||
|
|
||||||
return BioGPT()
|
|
||||||
|
|
||||||
|
|
||||||
# 36. Test if BioGPT provides a response for a simple biomedical question
|
|
||||||
def test_biomedical_response_1(biogpt_instance):
|
|
||||||
question = "What are the functions of the mitochondria?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 37. Test for a genetics-based question
|
|
||||||
def test_genetics_response(biogpt_instance):
|
|
||||||
question = "Can you explain the Mendelian inheritance?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 38. Test for a question about viruses
|
|
||||||
def test_virus_response(biogpt_instance):
|
|
||||||
question = "How do RNA viruses replicate?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 39. Test for a cell biology related question
|
|
||||||
def test_cell_biology_response(biogpt_instance):
|
|
||||||
question = "Describe the cell cycle and its phases."
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 40. Test for a question about protein structure
|
|
||||||
def test_protein_structure_response(biogpt_instance):
|
|
||||||
question = (
|
|
||||||
"What's the difference between alpha helix and beta sheet"
|
|
||||||
" structures in proteins?"
|
|
||||||
)
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 41. Test for a pharmacology question
|
|
||||||
def test_pharmacology_response(biogpt_instance):
|
|
||||||
question = "How do beta blockers work?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 42. Test for an anatomy-based question
|
|
||||||
def test_anatomy_response(biogpt_instance):
|
|
||||||
question = "Describe the structure of the human heart."
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 43. Test for a question about bioinformatics
|
|
||||||
def test_bioinformatics_response(biogpt_instance):
|
|
||||||
question = "What is a BLAST search?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 44. Test for a neuroscience question
|
|
||||||
def test_neuroscience_response(biogpt_instance):
|
|
||||||
question = "Explain the function of synapses in the nervous system."
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
# 45. Test for an immunology question
|
|
||||||
def test_immunology_response(biogpt_instance):
|
|
||||||
question = "What is the role of T cells in the immune response?"
|
|
||||||
response = biogpt_instance(question)
|
|
||||||
assert response
|
|
||||||
assert isinstance(response, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_init(bio_gpt):
|
|
||||||
assert bio_gpt.model_name == "microsoft/biogpt"
|
|
||||||
assert bio_gpt.max_length == 500
|
|
||||||
assert bio_gpt.num_return_sequences == 5
|
|
||||||
assert bio_gpt.do_sample is True
|
|
||||||
assert bio_gpt.min_length == 100
|
|
||||||
|
|
||||||
|
|
||||||
def test_call(bio_gpt, monkeypatch):
|
|
||||||
def mock_pipeline(*args, **kwargs):
|
|
||||||
class MockGenerator:
|
|
||||||
def __call__(self, text, **kwargs):
|
|
||||||
return ["Generated text"]
|
|
||||||
|
|
||||||
return MockGenerator()
|
|
||||||
|
|
||||||
monkeypatch.setattr("transformers.pipeline", mock_pipeline)
|
|
||||||
result = bio_gpt("Input text")
|
|
||||||
assert result == ["Generated text"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_features(bio_gpt):
|
|
||||||
features = bio_gpt.get_features("Input text")
|
|
||||||
assert "last_hidden_state" in features
|
|
||||||
|
|
||||||
|
|
||||||
def test_beam_search_decoding(bio_gpt):
|
|
||||||
generated_text = bio_gpt.beam_search_decoding("Input text")
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_pretrained_model(bio_gpt):
|
|
||||||
bio_gpt.set_pretrained_model("new_model")
|
|
||||||
assert bio_gpt.model_name == "new_model"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config(bio_gpt):
|
|
||||||
config = bio_gpt.get_config()
|
|
||||||
assert "vocab_size" in config
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_model(tmp_path, bio_gpt):
|
|
||||||
bio_gpt.save_model(tmp_path)
|
|
||||||
bio_gpt.load_from_path(tmp_path)
|
|
||||||
assert bio_gpt.model_name == "microsoft/biogpt"
|
|
||||||
|
|
||||||
|
|
||||||
def test_print_model(capsys, bio_gpt):
|
|
||||||
bio_gpt.print_model()
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert "BioGptForCausalLM" in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
# 26. Test if set_pretrained_model changes the model_name
|
|
||||||
def test_set_pretrained_model_name_change(biogpt_instance):
|
|
||||||
biogpt_instance.set_pretrained_model("new_model_name")
|
|
||||||
assert biogpt_instance.model_name == "new_model_name"
|
|
||||||
|
|
||||||
|
|
||||||
# 27. Test get_config return type
|
|
||||||
def test_get_config_return_type(biogpt_instance):
|
|
||||||
config = biogpt_instance.get_config()
|
|
||||||
assert isinstance(config, type(biogpt_instance.model.config))
|
|
||||||
|
|
||||||
|
|
||||||
# 28. Test saving model functionality by checking if files are created
|
|
||||||
@patch.object(BioGptForCausalLM, "save_pretrained")
|
|
||||||
@patch.object(BioGptTokenizer, "save_pretrained")
|
|
||||||
def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
|
|
||||||
path = "test_path"
|
|
||||||
biogpt_instance.save_model(path)
|
|
||||||
mock_save_model.assert_called_once_with(path)
|
|
||||||
mock_save_tokenizer.assert_called_once_with(path)
|
|
||||||
|
|
||||||
|
|
||||||
# 29. Test loading model from path
|
|
||||||
@patch.object(BioGptForCausalLM, "from_pretrained")
|
|
||||||
@patch.object(BioGptTokenizer, "from_pretrained")
|
|
||||||
def test_load_from_path(
|
|
||||||
mock_load_model, mock_load_tokenizer, biogpt_instance
|
|
||||||
):
|
|
||||||
path = "test_path"
|
|
||||||
biogpt_instance.load_from_path(path)
|
|
||||||
mock_load_model.assert_called_once_with(path)
|
|
||||||
mock_load_tokenizer.assert_called_once_with(path)
|
|
||||||
|
|
||||||
|
|
||||||
# 30. Test print_model doesn't raise any error
|
|
||||||
def test_print_model_metadata(biogpt_instance):
|
|
||||||
try:
|
|
||||||
biogpt_instance.print_model()
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"print_model() raised an exception: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# 31. Test that beam_search_decoding uses the correct number of beams
|
|
||||||
@patch.object(BioGptForCausalLM, "generate")
|
|
||||||
def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
|
|
||||||
biogpt_instance.beam_search_decoding("test_sentence", num_beams=7)
|
|
||||||
_, kwargs = mock_generate.call_args
|
|
||||||
assert kwargs["num_beams"] == 7
|
|
||||||
|
|
||||||
|
|
||||||
# 32. Test if beam_search_decoding handles early_stopping
|
|
||||||
@patch.object(BioGptForCausalLM, "generate")
|
|
||||||
def test_beam_search_decoding_early_stopping(
|
|
||||||
mock_generate, biogpt_instance
|
|
||||||
):
|
|
||||||
biogpt_instance.beam_search_decoding(
|
|
||||||
"test_sentence", early_stopping=False
|
|
||||||
)
|
|
||||||
_, kwargs = mock_generate.call_args
|
|
||||||
assert kwargs["early_stopping"] is False
|
|
||||||
|
|
||||||
|
|
||||||
# 33. Test get_features return type
|
|
||||||
def test_get_features_return_type(biogpt_instance):
|
|
||||||
result = biogpt_instance.get_features("This is a sample text.")
|
|
||||||
assert isinstance(result, torch.nn.modules.module.Module)
|
|
||||||
|
|
||||||
|
|
||||||
# 34. Test if default model is set correctly during initialization
|
|
||||||
def test_default_model_name(biogpt_instance):
|
|
||||||
assert biogpt_instance.model_name == "microsoft/biogpt"
|
|
@ -1,96 +0,0 @@
|
|||||||
import os
|
|
||||||
from unittest.mock import mock_open, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from swarms.models.eleven_labs import (
|
|
||||||
ElevenLabsModel,
|
|
||||||
ElevenLabsText2SpeechTool,
|
|
||||||
)
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# Define some test data
|
|
||||||
SAMPLE_TEXT = "Hello, this is a test."
|
|
||||||
API_KEY = os.environ.get("ELEVEN_API_KEY")
|
|
||||||
EXPECTED_SPEECH_FILE = "expected_speech.wav"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def eleven_labs_tool():
|
|
||||||
return ElevenLabsText2SpeechTool()
|
|
||||||
|
|
||||||
|
|
||||||
# Basic functionality tests
|
|
||||||
def test_run_text_to_speech(eleven_labs_tool):
|
|
||||||
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
|
|
||||||
assert isinstance(speech_file, str)
|
|
||||||
assert speech_file.endswith(".wav")
|
|
||||||
|
|
||||||
|
|
||||||
def test_play_speech(eleven_labs_tool):
|
|
||||||
with patch("builtins.open", mock_open(read_data="fake_audio_data")):
|
|
||||||
eleven_labs_tool.play(EXPECTED_SPEECH_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_speech(eleven_labs_tool):
|
|
||||||
with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
|
|
||||||
eleven_labs_tool.stream_speech(SAMPLE_TEXT)
|
|
||||||
mock_file.assert_called_with(
|
|
||||||
mode="bx", suffix=".wav", delete=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Testing fixture and environment variables
|
|
||||||
def test_api_key_validation(eleven_labs_tool):
|
|
||||||
with patch(
|
|
||||||
"langchain.utils.get_from_dict_or_env", return_value=API_KEY
|
|
||||||
):
|
|
||||||
values = {"eleven_api_key": None}
|
|
||||||
validated_values = eleven_labs_tool.validate_environment(values)
|
|
||||||
assert "eleven_api_key" in validated_values
|
|
||||||
|
|
||||||
|
|
||||||
# Mocking the external library
|
|
||||||
def test_run_text_to_speech_with_mock(eleven_labs_tool):
|
|
||||||
with patch(
|
|
||||||
"tempfile.NamedTemporaryFile", mock_open()
|
|
||||||
) as mock_file, patch(
|
|
||||||
"your_module._import_elevenlabs"
|
|
||||||
) as mock_elevenlabs:
|
|
||||||
mock_elevenlabs_instance = mock_elevenlabs.return_value
|
|
||||||
mock_elevenlabs_instance.generate.return_value = b"fake_audio_data"
|
|
||||||
eleven_labs_tool.run(SAMPLE_TEXT)
|
|
||||||
assert mock_file.call_args[1]["suffix"] == ".wav"
|
|
||||||
assert mock_file.call_args[1]["delete"] is False
|
|
||||||
assert mock_file().write.call_args[0][0] == b"fake_audio_data"
|
|
||||||
|
|
||||||
|
|
||||||
# Exception testing
|
|
||||||
def test_run_text_to_speech_error_handling(eleven_labs_tool):
|
|
||||||
with patch("your_module._import_elevenlabs") as mock_elevenlabs:
|
|
||||||
mock_elevenlabs_instance = mock_elevenlabs.return_value
|
|
||||||
mock_elevenlabs_instance.generate.side_effect = Exception(
|
|
||||||
"Test Exception"
|
|
||||||
)
|
|
||||||
with pytest.raises(
|
|
||||||
RuntimeError,
|
|
||||||
match=(
|
|
||||||
"Error while running ElevenLabsText2SpeechTool: Test"
|
|
||||||
" Exception"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
eleven_labs_tool.run(SAMPLE_TEXT)
|
|
||||||
|
|
||||||
|
|
||||||
# Parameterized testing
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model",
|
|
||||||
[ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL],
|
|
||||||
)
|
|
||||||
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
|
|
||||||
eleven_labs_tool.model = model
|
|
||||||
speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
|
|
||||||
assert isinstance(speech_file, str)
|
|
||||||
assert speech_file.endswith(".wav")
|
|
@ -1,177 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from swarms.models.gigabind import Gigabind
|
|
||||||
|
|
||||||
try:
|
|
||||||
import requests_mock
|
|
||||||
except ImportError:
|
|
||||||
requests_mock = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def api():
|
|
||||||
return Gigabind(host="localhost", port=8000, endpoint="embeddings")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock(requests_mock):
|
|
||||||
requests_mock.post(
|
|
||||||
"http://localhost:8000/embeddings", json={"result": "success"}
|
|
||||||
)
|
|
||||||
return requests_mock
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_text(api, mock):
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_vision(api, mock):
|
|
||||||
response = api.run(vision="image.jpg")
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_audio(api, mock):
|
|
||||||
response = api.run(audio="audio.mp3")
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_all(api, mock):
|
|
||||||
response = api.run(
|
|
||||||
text="Hello, world!", vision="image.jpg", audio="audio.mp3"
|
|
||||||
)
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_none(api):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
api.run()
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_summary(api, mock):
|
|
||||||
response = api.generate_summary(text="Hello, world!")
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_summary_with_none(api):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
api.generate_summary()
|
|
||||||
|
|
||||||
|
|
||||||
def test_retry_on_failure(api, requests_mock):
|
|
||||||
requests_mock.post(
|
|
||||||
"http://localhost:8000/embeddings",
|
|
||||||
[
|
|
||||||
{"status_code": 500, "json": {}},
|
|
||||||
{"status_code": 500, "json": {}},
|
|
||||||
{"status_code": 200, "json": {"result": "success"}},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_retry_exhausted(api, requests_mock):
|
|
||||||
requests_mock.post(
|
|
||||||
"http://localhost:8000/embeddings",
|
|
||||||
[
|
|
||||||
{"status_code": 500, "json": {}},
|
|
||||||
{"status_code": 500, "json": {}},
|
|
||||||
{"status_code": 500, "json": {}},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_proxy_url(api):
|
|
||||||
api.proxy_url = "http://proxy:8080"
|
|
||||||
assert api.url == "http://proxy:8080"
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_response(api, requests_mock):
|
|
||||||
requests_mock.post("http://localhost:8000/embeddings", text="not json")
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_connection_error(api, requests_mock):
|
|
||||||
requests_mock.post(
|
|
||||||
"http://localhost:8000/embeddings",
|
|
||||||
exc=requests.exceptions.ConnectTimeout,
|
|
||||||
)
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_http_error(api, requests_mock):
|
|
||||||
requests_mock.post("http://localhost:8000/embeddings", status_code=500)
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_url_construction(api):
|
|
||||||
assert api.url == "http://localhost:8000/embeddings"
|
|
||||||
|
|
||||||
|
|
||||||
def test_url_construction_with_proxy(api):
|
|
||||||
api.proxy_url = "http://proxy:8080"
|
|
||||||
assert api.url == "http://proxy:8080"
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_large_text(api, mock):
|
|
||||||
large_text = "Hello, world! " * 10000 # 10,000 repetitions
|
|
||||||
response = api.run(text=large_text)
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_large_vision(api, mock):
|
|
||||||
large_vision = "image.jpg" * 10000 # 10,000 repetitions
|
|
||||||
response = api.run(vision=large_vision)
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_large_audio(api, mock):
|
|
||||||
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
|
|
||||||
response = api.run(audio=large_audio)
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_large_all(api, mock):
|
|
||||||
large_text = "Hello, world! " * 10000 # 10,000 repetitions
|
|
||||||
large_vision = "image.jpg" * 10000 # 10,000 repetitions
|
|
||||||
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
|
|
||||||
response = api.run(
|
|
||||||
text=large_text, vision=large_vision, audio=large_audio
|
|
||||||
)
|
|
||||||
assert response == {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_timeout(api, mock):
|
|
||||||
response = api.run(text="Hello, world!", timeout=0.001)
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_invalid_host(api):
|
|
||||||
api.host = "invalid"
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_invalid_port(api):
|
|
||||||
api.port = 99999
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_invalid_endpoint(api):
|
|
||||||
api.endpoint = "invalid"
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_invalid_proxy_url(api):
|
|
||||||
api.proxy_url = "invalid"
|
|
||||||
response = api.run(text="Hello, world!")
|
|
||||||
assert response is None
|
|
@ -1,237 +0,0 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from swarms.models.huggingface import (
|
|
||||||
HuggingfaceLLM, # Replace with the actual import path
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Fixture for the class instance
|
|
||||||
@pytest.fixture
|
|
||||||
def llm_instance():
|
|
||||||
model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
|
||||||
instance = HuggingfaceLLM(model_id=model_id)
|
|
||||||
return instance
|
|
||||||
|
|
||||||
|
|
||||||
# Test for instantiation and attributes
|
|
||||||
def test_llm_initialization(llm_instance):
|
|
||||||
assert (
|
|
||||||
llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
|
||||||
)
|
|
||||||
assert llm_instance.max_length == 500
|
|
||||||
# ... add more assertions for all default attributes
|
|
||||||
|
|
||||||
|
|
||||||
# Parameterized test for setting devices
|
|
||||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
|
||||||
def test_llm_set_device(llm_instance, device):
|
|
||||||
llm_instance.set_device(device)
|
|
||||||
assert llm_instance.device == device
|
|
||||||
|
|
||||||
|
|
||||||
# Test exception during initialization with a bad model_id
|
|
||||||
def test_llm_bad_model_initialization():
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
HuggingfaceLLM(model_id="unknown-model")
|
|
||||||
|
|
||||||
|
|
||||||
# # Mocking the tokenizer and model to test run method
|
|
||||||
# @patch("swarms.models.huggingface.AutoTokenizer.from_pretrained")
|
|
||||||
# @patch(
|
|
||||||
# "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
|
|
||||||
# )
|
|
||||||
# def test_llm_run(mock_model, mock_tokenizer, llm_instance):
|
|
||||||
# mock_model.return_value.generate.return_value = "mocked output"
|
|
||||||
# mock_tokenizer.return_value.encode.return_value = "mocked input"
|
|
||||||
# result = llm_instance.run("test task")
|
|
||||||
# assert result == "mocked output"
|
|
||||||
|
|
||||||
|
|
||||||
# Async test (requires pytest-asyncio plugin)
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_llm_run_async(llm_instance):
|
|
||||||
result = await llm_instance.run_async("test task")
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
# Test for checking GPU availability
|
|
||||||
def test_llm_gpu_availability(llm_instance):
|
|
||||||
# Assuming the test is running on a machine where the GPU availability is known
|
|
||||||
expected_result = torch.cuda.is_available()
|
|
||||||
assert llm_instance.gpu_available() == expected_result
|
|
||||||
|
|
||||||
|
|
||||||
# Test for memory consumption reporting
|
|
||||||
def test_llm_memory_consumption(llm_instance):
|
|
||||||
# Mocking torch.cuda functions for consistent results
|
|
||||||
with patch("torch.cuda.memory_allocated", return_value=1024):
|
|
||||||
with patch("torch.cuda.memory_reserved", return_value=2048):
|
|
||||||
memory = llm_instance.memory_consumption()
|
|
||||||
assert memory == {"allocated": 1024, "reserved": 2048}
|
|
||||||
|
|
||||||
|
|
||||||
# Test different initialization parameters
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model_id, max_length",
|
|
||||||
[
|
|
||||||
("NousResearch/Nous-Hermes-2-Vision-Alpha", 100),
|
|
||||||
("microsoft/Orca-2-13b", 200),
|
|
||||||
(
|
|
||||||
"berkeley-nest/Starling-LM-7B-alpha",
|
|
||||||
None,
|
|
||||||
), # None to check default behavior
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_llm_initialization_params(model_id, max_length):
|
|
||||||
if max_length:
|
|
||||||
instance = HuggingfaceLLM(model_id=model_id, max_length=max_length)
|
|
||||||
assert instance.max_length == max_length
|
|
||||||
else:
|
|
||||||
instance = HuggingfaceLLM(model_id=model_id)
|
|
||||||
assert (
|
|
||||||
instance.max_length == 500
|
|
||||||
) # Assuming 500 is the default max_length
|
|
||||||
|
|
||||||
|
|
||||||
# Test for setting an invalid device
|
|
||||||
def test_llm_set_invalid_device(llm_instance):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
llm_instance.set_device("quantum_processor")
|
|
||||||
|
|
||||||
|
|
||||||
# Mocking external API call to test run method without network
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_run_without_network(mock_run, llm_instance):
|
|
||||||
mock_run.return_value = "mocked output"
|
|
||||||
result = llm_instance.run("test task without network")
|
|
||||||
assert result == "mocked output"
|
|
||||||
|
|
||||||
|
|
||||||
# Test handling of empty input for the run method
|
|
||||||
def test_llm_run_empty_input(llm_instance):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
llm_instance.run("")
|
|
||||||
|
|
||||||
|
|
||||||
# Test the generation with a provided seed for reproducibility
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_run_with_seed(mock_run, llm_instance):
|
|
||||||
seed = 42
|
|
||||||
llm_instance.set_seed(seed)
|
|
||||||
# Assuming set_seed method affects the randomness in the model
|
|
||||||
# You would typically ensure that setting the seed gives reproducible results
|
|
||||||
mock_run.return_value = "mocked deterministic output"
|
|
||||||
result = llm_instance.run("test task", seed=seed)
|
|
||||||
assert result == "mocked deterministic output"
|
|
||||||
|
|
||||||
|
|
||||||
# Test the output length is as expected
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_run_output_length(mock_run, llm_instance):
|
|
||||||
input_text = "test task"
|
|
||||||
llm_instance.max_length = 50 # set a max_length for the output
|
|
||||||
mock_run.return_value = "mocked output" * 10 # some long text
|
|
||||||
result = llm_instance.run(input_text)
|
|
||||||
assert len(result.split()) <= llm_instance.max_length
|
|
||||||
|
|
||||||
|
|
||||||
# Test the tokenizer handling special tokens correctly
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
|
|
||||||
def test_llm_tokenizer_special_tokens(
|
|
||||||
mock_decode, mock_encode, llm_instance
|
|
||||||
):
|
|
||||||
mock_encode.return_value = "encoded input with special tokens"
|
|
||||||
mock_decode.return_value = "decoded output with special tokens"
|
|
||||||
result = llm_instance.run("test task with special tokens")
|
|
||||||
mock_encode.assert_called_once()
|
|
||||||
mock_decode.assert_called_once()
|
|
||||||
assert "special tokens" in result
|
|
||||||
|
|
||||||
|
|
||||||
# Test for correct handling of timeouts
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_timeout_handling(mock_run, llm_instance):
|
|
||||||
mock_run.side_effect = TimeoutError
|
|
||||||
with pytest.raises(TimeoutError):
|
|
||||||
llm_instance.run("test task with timeout")
|
|
||||||
|
|
||||||
|
|
||||||
# Test for response time within a threshold (performance test)
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_response_time(mock_run, llm_instance):
|
|
||||||
import time
|
|
||||||
|
|
||||||
mock_run.return_value = "mocked output"
|
|
||||||
start_time = time.time()
|
|
||||||
llm_instance.run("test task for response time")
|
|
||||||
end_time = time.time()
|
|
||||||
assert (
|
|
||||||
end_time - start_time < 1
|
|
||||||
) # Assuming the response should be faster than 1 second
|
|
||||||
|
|
||||||
|
|
||||||
# Test the logging of a warning for long inputs
|
|
||||||
@patch("swarms.models.huggingface.logging.warning")
|
|
||||||
def test_llm_long_input_warning(mock_warning, llm_instance):
|
|
||||||
long_input = "x" * 10000 # input longer than the typical limit
|
|
||||||
llm_instance.run(long_input)
|
|
||||||
mock_warning.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
# Test for run method behavior when model raises an exception
|
|
||||||
@patch(
|
|
||||||
"swarms.models.huggingface.HuggingfaceLLM._model.generate",
|
|
||||||
side_effect=RuntimeError,
|
|
||||||
)
|
|
||||||
def test_llm_run_model_exception(mock_generate, llm_instance):
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
llm_instance.run("test task when model fails")
|
|
||||||
|
|
||||||
|
|
||||||
# Test the behavior when GPU is forced but not available
|
|
||||||
@patch("torch.cuda.is_available", return_value=False)
|
|
||||||
def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance):
|
|
||||||
with pytest.raises(EnvironmentError):
|
|
||||||
llm_instance.set_device(
|
|
||||||
"cuda"
|
|
||||||
) # Attempt to set CUDA when it's not available
|
|
||||||
|
|
||||||
|
|
||||||
# Test for proper cleanup after model use (releasing resources)
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM._model")
|
|
||||||
def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
|
|
||||||
llm_instance.cleanup()
|
|
||||||
# Assuming cleanup method is meant to free resources
|
|
||||||
mock_model.delete.assert_called_once()
|
|
||||||
mock_tokenizer.delete.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
# Test model's ability to handle multilingual input
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_multilingual_input(mock_run, llm_instance):
|
|
||||||
mock_run.return_value = "mocked multilingual output"
|
|
||||||
multilingual_input = "Bonjour, ceci est un test multilingue."
|
|
||||||
result = llm_instance.run(multilingual_input)
|
|
||||||
assert isinstance(
|
|
||||||
result, str
|
|
||||||
) # Simple check to ensure output is string type
|
|
||||||
|
|
||||||
|
|
||||||
# Test caching mechanism to prevent re-running the same inputs
|
|
||||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
||||||
def test_llm_caching_mechanism(mock_run, llm_instance):
|
|
||||||
input_text = "test caching mechanism"
|
|
||||||
mock_run.return_value = "cached output"
|
|
||||||
# Run the input twice
|
|
||||||
first_run_result = llm_instance.run(input_text)
|
|
||||||
second_run_result = llm_instance.run(input_text)
|
|
||||||
mock_run.assert_called_once() # Should only be called once due to caching
|
|
||||||
assert first_run_result == second_run_result
|
|
||||||
|
|
||||||
|
|
||||||
# These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class.
|
|
||||||
# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors that you need to replace with actual implementations.
|
|
@ -1,84 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from swarms.models.jina_embeds import JinaEmbeddings
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model():
|
|
||||||
return JinaEmbeddings("bert-base-uncased", verbose=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialization(model):
|
|
||||||
assert isinstance(model, JinaEmbeddings)
|
|
||||||
assert model.device in ["cuda", "cpu"]
|
|
||||||
assert model.max_length == 500
|
|
||||||
assert model.verbose is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_sync(model):
|
|
||||||
task = "Encode this text"
|
|
||||||
result = model.run(task)
|
|
||||||
assert isinstance(result, torch.Tensor)
|
|
||||||
assert result.shape == (model.max_length,)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_async(model):
|
|
||||||
task = "Encode this text"
|
|
||||||
result = model.run_async(task)
|
|
||||||
assert isinstance(result, torch.Tensor)
|
|
||||||
assert result.shape == (model.max_length,)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_model(tmp_path, model):
|
|
||||||
model_path = tmp_path / "model"
|
|
||||||
model.save_model(model_path)
|
|
||||||
assert (model_path / "config.json").is_file()
|
|
||||||
assert (model_path / "pytorch_model.bin").is_file()
|
|
||||||
assert (model_path / "vocab.txt").is_file()
|
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_available(model):
|
|
||||||
gpu_status = model.gpu_available()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
assert gpu_status is True
|
|
||||||
else:
|
|
||||||
assert gpu_status is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_consumption(model):
|
|
||||||
memory_stats = model.memory_consumption()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
assert "allocated" in memory_stats
|
|
||||||
assert "reserved" in memory_stats
|
|
||||||
else:
|
|
||||||
assert "error" in memory_stats
|
|
||||||
|
|
||||||
|
|
||||||
def test_cosine_similarity(model):
|
|
||||||
task1 = "This is a sample text for testing."
|
|
||||||
task2 = "Another sample text for testing."
|
|
||||||
embeddings1 = model.run(task1)
|
|
||||||
embeddings2 = model.run(task2)
|
|
||||||
sim = model.cos_sim(embeddings1, embeddings2)
|
|
||||||
assert isinstance(sim, torch.Tensor)
|
|
||||||
assert sim.item() >= -1.0
|
|
||||||
assert sim.item() <= 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_failed_load_model(caplog):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
JinaEmbeddings("invalid_model")
|
|
||||||
assert "Failed to load the model or the tokenizer" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_failed_generate_text(caplog, model):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
model.run("invalid_task")
|
|
||||||
assert "Failed to generate the text" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
|
||||||
def test_change_device(model, device):
|
|
||||||
model.device = device
|
|
||||||
assert model.device == device
|
|
@ -1,43 +0,0 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from swarms.models import TimmModel
|
|
||||||
|
|
||||||
|
|
||||||
def test_timm_model_init():
|
|
||||||
with patch("swarms.models.timm.list_models") as mock_list_models:
|
|
||||||
model_name = "resnet18"
|
|
||||||
pretrained = True
|
|
||||||
in_chans = 3
|
|
||||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
|
||||||
mock_list_models.assert_called_once()
|
|
||||||
assert timm_model.model_name == model_name
|
|
||||||
assert timm_model.pretrained == pretrained
|
|
||||||
assert timm_model.in_chans == in_chans
|
|
||||||
assert timm_model.models == mock_list_models.return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_timm_model_call():
|
|
||||||
with patch("swarms.models.timm.create_model") as mock_create_model:
|
|
||||||
model_name = "resnet18"
|
|
||||||
pretrained = True
|
|
||||||
in_chans = 3
|
|
||||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
|
||||||
task = torch.rand(1, in_chans, 224, 224)
|
|
||||||
result = timm_model(task)
|
|
||||||
mock_create_model.assert_called_once_with(
|
|
||||||
model_name, pretrained=pretrained, in_chans=in_chans
|
|
||||||
)
|
|
||||||
assert result == mock_create_model.return_value(task)
|
|
||||||
|
|
||||||
|
|
||||||
def test_timm_model_list_models():
|
|
||||||
with patch("swarms.models.timm.list_models") as mock_list_models:
|
|
||||||
model_name = "resnet18"
|
|
||||||
pretrained = True
|
|
||||||
in_chans = 3
|
|
||||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
|
||||||
result = timm_model.list_models()
|
|
||||||
mock_list_models.assert_called_once()
|
|
||||||
assert result == mock_list_models.return_value
|
|
@ -1,35 +0,0 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from swarms.models.ultralytics_model import UltralyticsModel
|
|
||||||
|
|
||||||
|
|
||||||
def test_ultralytics_init():
|
|
||||||
with patch("swarms.models.YOLO") as mock_yolo:
|
|
||||||
model_name = "yolov5s"
|
|
||||||
ultralytics = UltralyticsModel(model_name)
|
|
||||||
mock_yolo.assert_called_once_with(model_name)
|
|
||||||
assert ultralytics.model_name == model_name
|
|
||||||
assert ultralytics.model == mock_yolo.return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_ultralytics_call():
|
|
||||||
with patch("swarms.models.YOLO") as mock_yolo:
|
|
||||||
model_name = "yolov5s"
|
|
||||||
ultralytics = UltralyticsModel(model_name)
|
|
||||||
task = "detect"
|
|
||||||
args = (1, 2, 3)
|
|
||||||
kwargs = {"a": "A", "b": "B"}
|
|
||||||
result = ultralytics(task, *args, **kwargs)
|
|
||||||
mock_yolo.return_value.assert_called_once_with(
|
|
||||||
task, *args, **kwargs
|
|
||||||
)
|
|
||||||
assert result == mock_yolo.return_value.return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_ultralytics_list_models():
|
|
||||||
with patch("swarms.models.YOLO") as mock_yolo:
|
|
||||||
model_name = "yolov5s"
|
|
||||||
ultralytics = UltralyticsModel(model_name)
|
|
||||||
result = ultralytics.list_models()
|
|
||||||
mock_yolo.list_models.assert_called_once()
|
|
||||||
assert result == mock_yolo.list_models.return_value
|
|
@ -1,126 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from swarms.models.yi_200k import Yi34B200k
|
|
||||||
|
|
||||||
|
|
||||||
# Create fixtures if needed
|
|
||||||
@pytest.fixture
|
|
||||||
def yi34b_model():
|
|
||||||
return Yi34B200k()
|
|
||||||
|
|
||||||
|
|
||||||
# Test cases for the Yi34B200k class
|
|
||||||
def test_yi34b_init(yi34b_model):
|
|
||||||
assert isinstance(yi34b_model.model, torch.nn.Module)
|
|
||||||
assert isinstance(yi34b_model.tokenizer, AutoTokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(prompt)
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
assert len(generated_text) > 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
|
|
||||||
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(prompt, max_length=max_length)
|
|
||||||
assert len(generated_text) <= max_length
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
|
|
||||||
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(prompt, temperature=temperature)
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
|
|
||||||
prompt = None # Invalid prompt
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="Input prompt must be a non-empty string"
|
|
||||||
):
|
|
||||||
yi34b_model(prompt)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
max_length = -1 # Invalid max_length
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="max_length must be a positive integer"
|
|
||||||
):
|
|
||||||
yi34b_model(prompt, max_length=max_length)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
temperature = 2.0 # Invalid temperature
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="temperature must be between 0.01 and 1.0"
|
|
||||||
):
|
|
||||||
yi34b_model(prompt, temperature=temperature)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("top_k", [20, 30, 50])
|
|
||||||
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(prompt, top_k=top_k)
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
|
|
||||||
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(prompt, top_p=top_p)
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
top_k = -1 # Invalid top_k
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="top_k must be a non-negative integer"
|
|
||||||
):
|
|
||||||
yi34b_model(prompt, top_k=top_k)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
top_p = 1.5 # Invalid top_p
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="top_p must be between 0.0 and 1.0"
|
|
||||||
):
|
|
||||||
yi34b_model(prompt, top_p=top_p)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
|
|
||||||
def test_yi34b_generate_text_with_repitition_penalty(
|
|
||||||
yi34b_model, repitition_penalty
|
|
||||||
):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
generated_text = yi34b_model(
|
|
||||||
prompt, repitition_penalty=repitition_penalty
|
|
||||||
)
|
|
||||||
assert isinstance(generated_text, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_repitition_penalty(
|
|
||||||
yi34b_model,
|
|
||||||
):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
repitition_penalty = 0.0 # Invalid repitition_penalty
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match="repitition_penalty must be a positive float",
|
|
||||||
):
|
|
||||||
yi34b_model(prompt, repitition_penalty=repitition_penalty)
|
|
||||||
|
|
||||||
|
|
||||||
def test_yi34b_generate_text_with_invalid_device(yi34b_model):
|
|
||||||
prompt = "There's a place where time stands still."
|
|
||||||
device_map = "invalid_device" # Invalid device_map
|
|
||||||
with pytest.raises(ValueError, match="Invalid device_map"):
|
|
||||||
yi34b_model(prompt, device_map=device_map)
|
|
Loading…
Reference in new issue