parent
f93bc98952
commit
2f88e92930
@@ -1,80 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
|
||||
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
||||
from swarms.models.sam_supervision import SegmentAnythingMarkGenerator
|
||||
from swarms.utils.supervision_masking import refine_marks
|
||||
from swarms.utils.supervision_visualizer import MarkVisualizer
|
||||
|
||||
|
||||
class GPT4VSAM(BaseMultiModalModel):
|
||||
"""
|
||||
GPT4VSAM class represents a multi-modal model that combines the capabilities of GPT-4 and SegmentAnythingMarkGenerator.
|
||||
It takes an instance of BaseMultiModalModel (vlm) and a device as input and provides methods for loading images and making predictions.
|
||||
|
||||
Args:
|
||||
vlm (BaseMultiModalModel): An instance of BaseMultiModalModel representing the visual language model.
|
||||
device (str, optional): The device to be used for computation. Defaults to "cuda".
|
||||
|
||||
Attributes:
|
||||
vlm (BaseMultiModalModel): An instance of BaseMultiModalModel representing the visual language model.
|
||||
device (str): The device to be used for computation.
|
||||
sam (SegmentAnythingMarkGenerator): An instance of SegmentAnythingMarkGenerator for generating marks.
|
||||
visualizer (MarkVisualizer): An instance of MarkVisualizer for visualizing marks.
|
||||
|
||||
Methods:
|
||||
load_img(img: str) -> Any: Loads an image from the given file path.
|
||||
__call__(task: str, img: str, *args, **kwargs) -> Any: Makes predictions using the visual language model.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vlm: BaseMultiModalModel,
|
||||
device: str = "cuda",
|
||||
return_related_marks: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vlm = vlm
|
||||
self.device = device
|
||||
self.return_related_marks = return_related_marks
|
||||
|
||||
self.sam = SegmentAnythingMarkGenerator(device, *args, **kwargs)
|
||||
self.visualizer = MarkVisualizer(*args, **kwargs)
|
||||
|
||||
def load_img(self, img: str) -> Any:
|
||||
"""
|
||||
Loads an image from the given file path.
|
||||
|
||||
Args:
|
||||
img (str): The file path of the image.
|
||||
|
||||
Returns:
|
||||
Any: The loaded image.
|
||||
|
||||
"""
|
||||
return cv2.imread(img)
|
||||
|
||||
def __call__(self, task: str, img: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Makes predictions using the visual language model.
|
||||
|
||||
Args:
|
||||
task (str): The task for which predictions are to be made.
|
||||
img (str): The file path of the image.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The predictions made by the visual language model.
|
||||
|
||||
"""
|
||||
img = self.load_img(img)
|
||||
|
||||
marks = self.sam(image=img)
|
||||
marks = refine_marks(marks=marks)
|
||||
|
||||
if self.return_related_marks:
    return marks

return self.vlm(task, img, *args, **kwargs)
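A minimal usage sketch for `GPT4VSAM` (hedged: the `GPT4VisionAPI` import path and the image path below are assumptions, not confirmed by this diff):

```
from swarms.models.gpt4_vision_api import GPT4VisionAPI  # assumed import path

vlm = GPT4VisionAPI()
model = GPT4VSAM(vlm=vlm, device="cuda")
print(model("Describe the segmented regions", "examples/street.jpg"))  # hypothetical image path
```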
|
@@ -1,244 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from numpy.linalg import norm
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
from swarms.models.base_embedding_model import BaseEmbeddingModel
|
||||
|
||||
|
||||
def cos_sim(a, b):
|
||||
return a @ b.T / (norm(a) * norm(b))
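As a quick sanity check of `cos_sim` above, a minimal NumPy sketch:

```
import numpy as np

a = np.array([1.0, 0.0])
b = np.array([1.0, 1.0])
print(cos_sim(a, b))  # ~0.7071, since the vectors are 45 degrees apart
```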
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseEmbeddingModel):
|
||||
"""
|
||||
Jina Embeddings model.
|
||||
|
||||
Args:
|
||||
model_id (str): The model id to use. Default is "jinaai/jina-embeddings-v2-base-en".
|
||||
device (str): The device to run the model on. Defaults to "cuda" when available, otherwise "cpu".
|
||||
huggingface_api_key (str): The Hugging Face API key. Default is None.
|
||||
max_length (int): The maximum length of the response. Default is 500.
|
||||
quantize (bool): Whether to quantize the model. Default is False.
|
||||
quantization_config (dict): The quantization configuration. Default is None.
|
||||
verbose (bool): Whether to print verbose logs. Default is False.
|
||||
distributed (bool): Whether to use distributed processing. Default is False.
|
||||
decoding (bool): Whether to use decoding. Default is False.
|
||||
cos_sim (callable): The cosine similarity function. Default is cos_sim.
|
||||
|
||||
Methods:
|
||||
run: Embed the given text and return its embedding vector.
|
||||
|
||||
Examples:
|
||||
>>> model = JinaEmbeddings(
|
||||
>>> max_length=8192,
|
||||
>>> device="cuda",
|
||||
>>> quantize=True,
|
||||
>>> huggingface_api_key="hf_..."  # placeholder; use your own Hugging Face token
|
||||
>>> )
|
||||
>>> embeddings = model("Encode this super long document text")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "jinaai/jina-embeddings-v2-base-en",
|
||||
device: str = None,
|
||||
huggingface_api_key: str = None,
|
||||
max_length: int = 500,
|
||||
quantize: bool = False,
|
||||
quantization_config: dict = None,
|
||||
verbose=False,
|
||||
distributed=False,
|
||||
decoding=False,
|
||||
cos_sim: callable = cos_sim,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.device = (
|
||||
device
|
||||
if device
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
self.huggingface_api_key = huggingface_api_key
|
||||
self.model_id = model_id
|
||||
self.max_length = max_length
|
||||
self.verbose = verbose
|
||||
self.distributed = distributed
|
||||
self.decoding = decoding
|
||||
self.quantization_config = quantization_config
self.model, self.tokenizer = None, None
|
||||
self.cos_sim = cos_sim
|
||||
|
||||
if self.distributed:
|
||||
assert (
|
||||
torch.cuda.device_count() > 1
|
||||
), "You need more than 1 gpu for distributed processing"
|
||||
|
||||
# If API key then set it
|
||||
if self.huggingface_api_key:
|
||||
os.environ["HF_TOKEN"] = self.huggingface_api_key
|
||||
|
||||
bnb_config = None
|
||||
if quantize:
|
||||
if not quantization_config:
|
||||
quantization_config = {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_use_double_quant": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
}
|
||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
||||
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id,
|
||||
quantization_config=bnb_config,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
self.model # .to(self.device)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to load the model or the tokenizer: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
"""Load the model"""
|
||||
if not self.model or not self.tokenizer:
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_id
|
||||
)
|
||||
|
||||
bnb_config = (
|
||||
BitsAndBytesConfig(**self.quantization_config)
|
||||
if self.quantization_config
|
||||
else None
|
||||
)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id,
|
||||
quantization_config=bnb_config,
|
||||
trust_remote_code=True,
|
||||
).to(self.device)
|
||||
|
||||
if self.distributed:
|
||||
self.model = DDP(self.model)
|
||||
except Exception as error:
|
||||
self.logger.error(
|
||||
"Failed to load the model or the tokenizer:"
|
||||
f" {error}"
|
||||
)
|
||||
raise
|
||||
|
||||
def run(self, task: str, *args, **kwargs):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- task (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
embeddings = self.model.encode(
|
||||
[task], max_length=max_length, *args, **kwargs
|
||||
)
|
||||
|
||||
if self.cos_sim and len(embeddings) > 1:
    print(self.cos_sim(embeddings[0], embeddings[1]))
return embeddings[0]
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
||||
|
||||
async def run_async(self, task: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
Run the model asynchronously
|
||||
|
||||
Args:
|
||||
task (str): Task to run.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Examples:
|
||||
>>> model = JinaEmbeddings(max_length=8192)
>>> embedding = await model.run_async("Encode this text")
|
||||
|
||||
"""
|
||||
# Wrapping synchronous calls with async
|
||||
return self.run(task, *args, **kwargs)
|
||||
|
||||
def __call__(self, task: str, *args, **kwargs):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- task (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
self.load_model()
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
embeddings = self.model.encode(
|
||||
[task], max_length=max_length, *args, **kwargs
|
||||
)
|
||||
|
||||
if self.cos_sim and len(embeddings) > 1:
    print(self.cos_sim(embeddings[0], embeddings[1]))
return embeddings[0]
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
||||
|
||||
async def __call_async__(self, task: str, *args, **kwargs) -> str:
|
||||
"""Call the model asynchronously""" ""
|
||||
return await self.run_async(task, *args, **kwargs)
|
||||
|
||||
def save_model(self, path: str):
|
||||
"""Save the model to a given path"""
|
||||
self.model.save_pretrained(path)
|
||||
self.tokenizer.save_pretrained(path)
|
||||
|
||||
def gpu_available(self) -> bool:
|
||||
"""Check if GPU is available"""
|
||||
return torch.cuda.is_available()
|
||||
|
||||
def memory_consumption(self) -> dict:
|
||||
"""Get the memory consumption of the GPU"""
|
||||
if self.gpu_available():
|
||||
torch.cuda.synchronize()
|
||||
allocated = torch.cuda.memory_allocated()
|
||||
reserved = torch.cuda.memory_reserved()
|
||||
return {"allocated": allocated, "reserved": reserved}
|
||||
else:
|
||||
return {"error": "GPU not available"}
|
||||
|
||||
def try_embed_chunk(self, chunk: str) -> list[float]:
|
||||
return super().try_embed_chunk(chunk)
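A hedged usage sketch for `JinaEmbeddings`, mirroring the docstring example above (assumes the `jinaai/jina-embeddings-v2-base-en` weights can be downloaded; the token is a placeholder):

```
model = JinaEmbeddings(
    max_length=8192,
    device="cuda",
    huggingface_api_key="hf_...",  # placeholder token
)
embedding = model("Encode this super long document text")
print(embedding.shape)
```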
|
@@ -1,142 +0,0 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from skimage import transform
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
# `sam_model_registry` is the dict of SAM model builders provided by the
# `segment-anything` package (pip install segment-anything).
from segment_anything import sam_model_registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class MedicalSAM:
|
||||
"""
|
||||
MedicalSAM class for performing semantic segmentation on medical images using the SAM model.
|
||||
|
||||
Attributes:
|
||||
model_path (str): The file path to the model weights.
|
||||
device (str): The device to run the model on (default is "cuda:0").
|
||||
model_weights_url (str): The URL to download the model weights from.
|
||||
|
||||
Methods:
|
||||
__post_init__(): Initializes the MedicalSAM object.
|
||||
download_model_weights(model_path: str): Downloads the model weights from the specified URL and saves them to the given file path.
|
||||
preprocess(img): Preprocesses the input image.
|
||||
run(img, box): Runs the semantic segmentation on the input image within the specified bounding box.
|
||||
|
||||
"""
|
||||
|
||||
model_path: str
|
||||
device: str = "cuda:0"
|
||||
model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
||||
|
||||
def __post_init__(self):
|
||||
if not os.path.exists(self.model_path):
|
||||
self.download_model_weights(self.model_path)
|
||||
|
||||
self.model = sam_model_registry["vit_b"](
|
||||
checkpoint=self.model_path
|
||||
)
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def download_model_weights(self, model_path: str):
|
||||
"""
|
||||
Downloads the model weights from the specified URL and saves them to the given file path.
|
||||
|
||||
Args:
|
||||
model_path (str): The file path where the model weights will be saved.
|
||||
|
||||
Raises:
|
||||
Exception: If the model weights fail to download.
|
||||
"""
|
||||
response = requests.get(self.model_weights_url, stream=True)
|
||||
if response.status_code == 200:
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
raise Exception("Failed to download model weights.")
|
||||
|
||||
def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
|
||||
"""
|
||||
Preprocesses the input image.
|
||||
|
||||
Args:
|
||||
img: The input image.
|
||||
|
||||
Returns:
|
||||
img_tensor: The preprocessed image tensor.
|
||||
H: The original height of the image.
|
||||
W: The original width of the image.
|
||||
"""
|
||||
if len(img.shape) == 2:
|
||||
img = np.repeat(img[:, :, None], 3, axis=-1)
|
||||
H, W, _ = img.shape
|
||||
img = transform.resize(
|
||||
img,
|
||||
(1024, 1024),
|
||||
order=3,
|
||||
preserve_range=True,
|
||||
anti_aliasing=True,
|
||||
).astype(np.uint8)
|
||||
img = (img - img.min()) / np.clip(
    img.max() - img.min(), a_min=1e-8, a_max=None
)
|
||||
img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
|
||||
return img, H, W
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Runs the semantic segmentation on the input image within the specified bounding box.
|
||||
|
||||
Args:
|
||||
img: The input image.
|
||||
box: The bounding box coordinates (x1, y1, x2, y2).
|
||||
|
||||
Returns:
|
||||
medsam_seg: The segmented image.
|
||||
"""
|
||||
img_tensor, H, W = self.preprocess(img)
|
||||
img_tensor = img_tensor.to(self.device)
|
||||
box_1024 = box / np.array([W, H, W, H]) * 1024
|
||||
img = self.model.image_encoder(img_tensor)
|
||||
|
||||
box_torch = torch.as_tensor(
|
||||
box_1024, dtype=torch.float, device=img_tensor.device
|
||||
)
|
||||
|
||||
if len(box_torch.shape) == 2:
|
||||
box_torch = box_torch[:, None, :]
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=None,
|
||||
boxes=box_torch,
|
||||
masks=None,
|
||||
)
|
||||
|
||||
low_res_logits, _ = self.model.mask_decoder(
|
||||
image_embeddings=img,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
low_res_pred = torch.sigmoid(low_res_logits)
|
||||
low_res_pred = F.interpolate(
|
||||
low_res_pred,
|
||||
size=(H, W),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
low_res_pred = low_res_pred.squeeze().cpu().numpy()
|
||||
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
||||
|
||||
return medsam_seg
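A hedged usage sketch for `MedicalSAM.run` (the checkpoint path and the random array standing in for a real scan are assumptions):

```
import numpy as np

sam = MedicalSAM(model_path="weights/sam_vit_b_01ec64.pth", device="cuda:0")
image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # stand-in for a real scan
box = np.array([100, 100, 400, 400])  # (x1, y1, x2, y2) region of interest
mask = sam.run(image, box)
print(mask.shape, mask.dtype)  # (512, 512) uint8 binary mask
```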
|
@@ -1,10 +0,0 @@
|
||||
from swarms.models.popular_llms import OpenAIChat


class MistralAPILLM(OpenAIChat):
    def __init__(self, url):
        super().__init__()
        self.openai_proxy_url = url

    def __call__(self, task: str):
        return super().__call__(task)
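A minimal usage sketch for the proxy wrapper above (the URL is hypothetical; it assumes `OpenAIChat` honors the `openai_proxy_url` attribute):

```
llm = MistralAPILLM(url="http://localhost:8000/v1")  # hypothetical proxy endpoint
print(llm("Summarize the plot of Hamlet in two sentences."))
```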
|
@@ -1,553 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Callable, Literal, Sequence
|
||||
|
||||
import numpy as np
|
||||
from pydantic import model_validator, ConfigDict, BaseModel, Field
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from swarms.models.embeddings_base import Embeddings
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
values: dict, key: str, env_key: str, default: Any = None
|
||||
) -> Any:
|
||||
import os
|
||||
|
||||
return values.get(key) or os.getenv(env_key) or default
|
||||
|
||||
|
||||
def get_pydantic_field_names(cls: Any) -> set[str]:
|
||||
return set(cls.__annotations__.keys())
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
embeddings: OpenAIEmbeddings,
|
||||
) -> Callable[[Any], Any]:
|
||||
import llm
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(embeddings.max_retries),
|
||||
wait=wait_exponential(
|
||||
multiplier=1, min=min_seconds, max=max_seconds
|
||||
),
|
||||
retry=(
|
||||
retry_if_exception_type(llm.error.Timeout)
|
||||
| retry_if_exception_type(llm.error.APIError)
|
||||
| retry_if_exception_type(llm.error.APIConnectionError)
|
||||
| retry_if_exception_type(llm.error.RateLimitError)
|
||||
| retry_if_exception_type(llm.error.ServiceUnavailableError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
|
||||
import llm
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
async_retrying = AsyncRetrying(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(embeddings.max_retries),
|
||||
wait=wait_exponential(
|
||||
multiplier=1, min=min_seconds, max=max_seconds
|
||||
),
|
||||
retry=(
|
||||
retry_if_exception_type(llm.error.Timeout)
|
||||
| retry_if_exception_type(llm.error.APIError)
|
||||
| retry_if_exception_type(llm.error.APIConnectionError)
|
||||
| retry_if_exception_type(llm.error.RateLimitError)
|
||||
| retry_if_exception_type(llm.error.ServiceUnavailableError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
def wrap(func: Callable) -> Callable:
|
||||
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
|
||||
async for _ in async_retrying:
|
||||
return await func(*args, **kwargs)
|
||||
raise AssertionError("this is unreachable")
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
|
||||
def _check_response(response: dict) -> dict:
|
||||
if any(len(d["embedding"]) == 1 for d in response["data"]):
|
||||
import llm
|
||||
|
||||
raise llm.error.APIError("OpenAI API returned an empty embedding")
|
||||
return response
|
||||
|
||||
|
||||
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
retry_decorator = _create_retry_decorator(embeddings)
|
||||
|
||||
@retry_decorator
|
||||
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = embeddings.client.create(**kwargs)
|
||||
return _check_response(response)
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def async_embed_with_retry(
|
||||
embeddings: OpenAIEmbeddings, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
|
||||
@_async_retry_decorator(embeddings)
|
||||
async def _async_embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = await embeddings.client.acreate(**kwargs)
|
||||
return _check_response(response)
|
||||
|
||||
return await _async_embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""OpenAI embedding models.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``OPENAI_API_KEY`` set with your API key or pass it
|
||||
as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
openai = OpenAIEmbeddings(openai_api_key="my-api-key")
|
||||
|
||||
In order to use the library with Microsoft Azure endpoints, you need to set
|
||||
the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and OPENAI_API_VERSION.
|
||||
The OPENAI_API_TYPE must be set to 'azure' and the others correspond to
|
||||
the properties of your endpoint.
|
||||
In addition, the deployment name must be passed as the model parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
|
||||
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
|
||||
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
|
||||
os.environ["OPENAI_PROXY"] = "http://your-corporate-proxy:8080"
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
deployment="your-embeddings-deployment-name",
|
||||
model="your-embeddings-model-name",
|
||||
openai_api_base="https://your-endpoint.openai.azure.com/",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
text = "This is a test query."
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "text-embedding-ada-002"
|
||||
deployment: str = (
|
||||
model # to support Azure OpenAI Service custom deployment names
|
||||
)
|
||||
openai_api_version: str | None = None
|
||||
# to support Azure OpenAI Service custom endpoints
|
||||
openai_api_base: str | None = None
|
||||
# to support Azure OpenAI Service custom endpoints
|
||||
openai_api_type: str | None = None
|
||||
# to support explicit proxy for OpenAI
|
||||
openai_proxy: str | None = None
|
||||
embedding_ctx_length: int = 8191
|
||||
"""The maximum number of tokens to embed at once."""
|
||||
openai_api_key: str | None = None
|
||||
openai_organization: str | None = None
|
||||
allowed_special: Literal["all"] | set[str] = set()
|
||||
disallowed_special: Literal["all"] | set[str] | Sequence[str] = "all"
|
||||
chunk_size: int = 1000
|
||||
"""Maximum number of texts to embed in each batch"""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: float | tuple[float, float] | None = None
|
||||
"""Timeout in seconds for the OpenAPI request."""
|
||||
headers: Any = None
|
||||
tiktoken_model_name: str | None = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding."""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(
|
||||
extra.keys()
|
||||
)
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be"
|
||||
" specified explicitly. Instead they were passed in"
|
||||
" as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
default="",
|
||||
)
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
default="",
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
if values["openai_api_type"] in (
|
||||
"azure",
|
||||
"azure_ad",
|
||||
"azuread",
|
||||
):
|
||||
default_api_version = "2022-12-01"
|
||||
else:
|
||||
default_api_version = ""
|
||||
values["openai_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
default=default_api_version,
|
||||
)
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
default="",
|
||||
)
|
||||
try:
|
||||
import llm
|
||||
|
||||
values["client"] = llm.Embedding
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> dict:
|
||||
openai_args = {
|
||||
"model": self.model,
|
||||
"request_timeout": self.request_timeout,
|
||||
"headers": self.headers,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
|
||||
openai_args["engine"] = self.deployment
|
||||
if self.openai_proxy:
|
||||
import llm
|
||||
|
||||
llm.proxy = {
|
||||
"http": self.openai_proxy,
|
||||
"https": self.openai_proxy,
|
||||
} # type: ignore[assignment] # noqa: E501
|
||||
return openai_args
|
||||
|
||||
def _get_len_safe_embeddings(
|
||||
self,
|
||||
texts: list[str],
|
||||
*,
|
||||
engine: str,
|
||||
chunk_size: int | None = None,
|
||||
) -> list[list[float]]:
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to for OpenAIEmbeddings. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Warning: model not found. Using cl100k_base" " encoding."
|
||||
)
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
token = encoding.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
for j in range(0, len(token), self.embedding_ctx_length):
|
||||
tokens.append(token[j : j + self.embedding_ctx_length])
|
||||
indices.append(i)
|
||||
|
||||
batched_embeddings: list[list[float]] = []
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
|
||||
if self.show_progress_bar:
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
_iter = tqdm.tqdm(range(0, len(tokens), _chunk_size))
|
||||
except ImportError:
|
||||
_iter = range(0, len(tokens), _chunk_size)
|
||||
else:
|
||||
_iter = range(0, len(tokens), _chunk_size)
|
||||
|
||||
for i in _iter:
|
||||
response = embed_with_retry(
|
||||
self,
|
||||
input=tokens[i : i + _chunk_size],
|
||||
**self._invocation_params,
|
||||
)
|
||||
batched_embeddings.extend(
|
||||
r["embedding"] for r in response["data"]
|
||||
)
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [
|
||||
[] for _ in range(len(texts))
|
||||
]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
average = embed_with_retry(
|
||||
self,
|
||||
input="",
|
||||
**self._invocation_params,
|
||||
)["data"][0]["embedding"]
|
||||
else:
|
||||
average = np.average(
|
||||
_result, axis=0, weights=num_tokens_in_batch[i]
|
||||
)
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
return embeddings
|
||||
|
||||
# please refer to
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
async def _aget_len_safe_embeddings(
|
||||
self,
|
||||
texts: list[str],
|
||||
*,
|
||||
engine: str,
|
||||
chunk_size: int | None = None,
|
||||
) -> list[list[float]]:
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to for OpenAIEmbeddings. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Warning: model not found. Using cl100k_base" " encoding."
|
||||
)
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
token = encoding.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
for j in range(0, len(token), self.embedding_ctx_length):
|
||||
tokens.append(token[j : j + self.embedding_ctx_length])
|
||||
indices.append(i)
|
||||
|
||||
batched_embeddings: list[list[float]] = []
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
for i in range(0, len(tokens), _chunk_size):
|
||||
response = await async_embed_with_retry(
|
||||
self,
|
||||
input=tokens[i : i + _chunk_size],
|
||||
**self._invocation_params,
|
||||
)
|
||||
batched_embeddings.extend(
|
||||
r["embedding"] for r in response["data"]
|
||||
)
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [
|
||||
[] for _ in range(len(texts))
|
||||
]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
average = (
|
||||
await async_embed_with_retry(
|
||||
self,
|
||||
input="",
|
||||
**self._invocation_params,
|
||||
)
|
||||
)["data"][0]["embedding"]
|
||||
else:
|
||||
average = np.average(
|
||||
_result, axis=0, weights=num_tokens_in_batch[i]
|
||||
)
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(
|
||||
self, texts: list[str], chunk_size: int | None = 0
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: list[str], chunk_size: int | None = 0
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return await self._aget_len_safe_embeddings(
|
||||
texts, engine=self.deployment
|
||||
)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embeddings = await self.aembed_documents([text])
|
||||
return embeddings[0]
|
||||
__all__ = [
|
||||
"OpenAIEmbeddings",
|
||||
]
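A hedged usage sketch for `OpenAIEmbeddings` above (requires the `openai` and `tiktoken` packages and an `OPENAI_API_KEY` in the environment):

```
embedder = OpenAIEmbeddings(chunk_size=500)
docs = ["Swarms orchestrates many agents.", "Embeddings map text to vectors."]
doc_vectors = embedder.embed_documents(docs)
query_vector = embedder.embed_query("How are agents orchestrated?")
print(len(doc_vectors), len(query_vector))  # 2 document vectors; query vector length = embedding dimension
```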
|
@@ -1,182 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from langchain_community.llms.google_palm import GooglePalm
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
from pydantic import model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
try:
|
||||
import google.api_core.exceptions
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-api-core python package. "
|
||||
"Please install it with `pip install google-api-core`."
|
||||
)
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(
|
||||
multiplier=multiplier, min=min_seconds, max=max_seconds
|
||||
),
|
||||
retry=(
|
||||
retry_if_exception_type(
|
||||
google.api_core.exceptions.ResourceExhausted
|
||||
)
|
||||
| retry_if_exception_type(
|
||||
google.api_core.exceptions.ServiceUnavailable
|
||||
)
|
||||
| retry_if_exception_type(
|
||||
google.api_core.exceptions.GoogleAPIError
|
||||
)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _generate_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.generate_text(**kwargs)
|
||||
|
||||
return _generate_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
"""Strip erroneous leading spaces from text.
|
||||
|
||||
The PaLM API will sometimes erroneously return a single leading space in all
|
||||
lines > 1. This function strips that space.
|
||||
"""
|
||||
has_leading_space = all(
|
||||
not line or line[0] == " " for line in text.split("\n")[1:]
|
||||
)
|
||||
if has_leading_space:
|
||||
return text.replace("\n ", "\n")
|
||||
else:
|
||||
return text
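A quick check of the helper above, showing the single leading space being stripped from every line after the first:

```
text = "Line one\n indented by the API\n another line"
print(_strip_erroneous_leading_spaces(text))
# Line one
# indented by the API
# another line
```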
|
||||
|
||||
|
||||
class GooglePalm(BaseLLM, BaseModel):
|
||||
"""Google PaLM models."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: str | None
|
||||
model_name: str = "models/text-bison-001"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""Run inference with this temperature. Must by in the closed interval
|
||||
[0.0, 1.0]."""
|
||||
top_p: float | None = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: int | None = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
max_output_tokens: int | None = None
|
||||
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
||||
If unset, will default to 64."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate api key, python package exists."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-generativeai python package."
|
||||
" Please install it with `pip install"
|
||||
" google-generativeai`."
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
|
||||
if (
|
||||
values["temperature"] is not None
|
||||
and not 0 <= values["temperature"] <= 1
|
||||
):
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
if (
|
||||
values["max_output_tokens"] is not None
|
||||
and values["max_output_tokens"] <= 0
|
||||
):
|
||||
raise ValueError("max_output_tokens must be greater than zero")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
completion = generate_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stop_sequences=stop,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
candidate_count=self.n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prompt_generations = []
|
||||
for candidate in completion.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "google_palm"
|
||||
__all__ = [
|
||||
"GooglePalm",
|
||||
]
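A hedged usage sketch for the `GooglePalm` wrapper above (requires the `google-generativeai` package and a valid API key, both assumptions here; `_generate` is the method defined in this file):

```
palm = GooglePalm(google_api_key="YOUR_API_KEY", temperature=0.5)
result = palm._generate(["Write a haiku about the sea."])
print(result.generations[0][0].text)
```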
|
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
|
||||
TROCR for Multi-Modal OCR tasks
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class TrOCR:
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
pass
|
||||
|
||||
def __call__(self):
|
||||
pass
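The class above is only a stub. As a hedged sketch of how it could be filled in with Hugging Face's TrOCR checkpoints (this uses the `transformers` `TrOCRProcessor` / `VisionEncoderDecoderModel` APIs and is not the original swarms implementation):

```
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel


class TrOCRSketch:
    """Minimal OCR wrapper sketch; not the original implementation."""

    def __init__(self, model_name: str = "microsoft/trocr-base-printed", device: str = "cpu"):
        self.processor = TrOCRProcessor.from_pretrained(model_name)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
        self.device = device

    def __call__(self, img_path: str) -> str:
        # Encode the image to pixel values, generate token ids, and decode them to text.
        image = Image.open(img_path).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        generated_ids = self.model.generate(pixel_values)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```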
|
@@ -1,52 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
||||
|
||||
|
||||
class UltralyticsModel(BaseMultiModalModel):
|
||||
"""
|
||||
Initializes an instance of the Ultralytics model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "yolov8n.pt", *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_name = model_name
|
||||
|
||||
try:
|
||||
self.model = YOLO(model_name, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to initialize Ultralytics model: {str(e)}"
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, task: str, tasks: List[str] = None, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Calls the Ultralytics model.
|
||||
|
||||
Args:
|
||||
task (str): The image source to run inference on (file path, URL, or array).
tasks (List[str], optional): A batch of sources to process in one call.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result of the model call.
|
||||
"""
|
||||
try:
|
||||
if tasks:
|
||||
return self.model(tasks, *args, **kwargs)
|
||||
else:
|
||||
return self.model(task, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to perform task '{task}' with Ultralytics"
|
||||
f" model: {str(e)}"
|
||||
)
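A hedged usage sketch for the wrapper above (downloads `yolov8n.pt` on first use; the image path is a placeholder):

```
detector = UltralyticsModel("yolov8n.pt")
results = detector("examples/street.jpg")  # hypothetical image path
for result in results:
    print(result.boxes.xyxy, result.boxes.cls)  # detected boxes and class ids
```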
|
@@ -1,237 +0,0 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
|
||||
class WizardLLMStoryTeller:
|
||||
"""
|
||||
A class for running inference on a given model.
|
||||
|
||||
Attributes:
|
||||
model_id (str): The ID of the model.
|
||||
device (str): The device to run the model on (either 'cuda' or 'cpu').
|
||||
max_length (int): The maximum length of the output sequence.
|
||||
quantize (bool, optional): Whether to use quantization. Defaults to False.
|
||||
quantization_config (dict, optional): The configuration for quantization.
|
||||
verbose (bool, optional): Whether to print verbose logs. Defaults to False.
|
||||
logger (logging.Logger, optional): The logger to use. Defaults to a basic logger.
|
||||
|
||||
# Usage
|
||||
```
|
||||
model_id = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF"
inference = WizardLLMStoryTeller(model_id=model_id)
|
||||
|
||||
prompt_text = "Once upon a time"
|
||||
generated_text = inference(prompt_text)
|
||||
print(generated_text)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF",
|
||||
device: str = None,
|
||||
max_length: int = 500,
|
||||
quantize: bool = False,
|
||||
quantization_config: dict = None,
|
||||
verbose=False,
|
||||
# logger=None,
|
||||
distributed=False,
|
||||
decoding=False,
|
||||
):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.device = (
|
||||
device
|
||||
if device
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
self.model_id = model_id
|
||||
self.max_length = max_length
|
||||
self.verbose = verbose
|
||||
self.distributed = distributed
|
||||
self.decoding = decoding
|
||||
self.quantization_config = quantization_config
self.model, self.tokenizer = None, None
|
||||
# self.log = Logging()
|
||||
|
||||
if self.distributed:
|
||||
assert (
|
||||
torch.cuda.device_count() > 1
|
||||
), "You need more than 1 gpu for distributed processing"
|
||||
|
||||
bnb_config = None
|
||||
if quantize:
|
||||
if not quantization_config:
|
||||
quantization_config = {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_use_double_quant": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
}
|
||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, quantization_config=bnb_config
|
||||
)
|
||||
|
||||
self.model # .to(self.device)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to load the model or the tokenizer: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def load_model(self):
|
||||
"""Load the model"""
|
||||
if not self.model or not self.tokenizer:
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_id
|
||||
)
|
||||
|
||||
bnb_config = (
|
||||
BitsAndBytesConfig(**self.quantization_config)
|
||||
if self.quantization_config
|
||||
else None
|
||||
)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, quantization_config=bnb_config
|
||||
).to(self.device)
|
||||
|
||||
if self.distributed:
|
||||
self.model = DDP(self.model)
|
||||
except Exception as error:
|
||||
self.logger.error(
|
||||
"Failed to load the model or the tokenizer:"
|
||||
f" {error}"
|
||||
)
|
||||
raise
|
||||
|
||||
def run(self, prompt_text: str):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- prompt_text (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
self.load_model()
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
inputs = self.tokenizer.encode(
|
||||
prompt_text, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# self.log.start()
|
||||
|
||||
if self.decoding:
|
||||
with torch.no_grad():
|
||||
for _ in range(max_length):
|
||||
output_sequence = []
|
||||
|
||||
outputs = self.model.generate(
|
||||
inputs,
|
||||
max_length=len(inputs) + 1,
|
||||
do_sample=True,
|
||||
)
|
||||
output_tokens = outputs[0][-1]
|
||||
output_sequence.append(output_tokens.item())
|
||||
|
||||
# print token in real-time
|
||||
print(
|
||||
self.tokenizer.decode(
|
||||
[output_tokens],
|
||||
skip_special_tokens=True,
|
||||
),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
inputs = outputs
|
||||
else:
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
inputs, max_length=max_length, do_sample=True
|
||||
)
|
||||
|
||||
del inputs
|
||||
return self.tokenizer.decode(
|
||||
outputs[0], skip_special_tokens=True
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
||||
|
||||
def __call__(self, prompt_text: str):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- prompt_text (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
self.load_model()
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
inputs = self.tokenizer.encode(
|
||||
prompt_text, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# self.log.start()
|
||||
|
||||
if self.decoding:
|
||||
with torch.no_grad():
|
||||
for _ in range(max_length):
|
||||
output_sequence = []
|
||||
|
||||
outputs = self.model.generate(
|
||||
inputs,
|
||||
max_length=len(inputs) + 1,
|
||||
do_sample=True,
|
||||
)
|
||||
output_tokens = outputs[0][-1]
|
||||
output_sequence.append(output_tokens.item())
|
||||
|
||||
# print token in real-time
|
||||
print(
|
||||
self.tokenizer.decode(
|
||||
[output_tokens],
|
||||
skip_special_tokens=True,
|
||||
),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
inputs = outputs
|
||||
else:
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
inputs, max_length=max_length, do_sample=True
|
||||
)
|
||||
|
||||
del inputs
|
||||
|
||||
return self.tokenizer.decode(
|
||||
outputs[0], skip_special_tokens=True
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
@@ -1,288 +0,0 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
|
||||
class YarnMistral128:
|
||||
"""
|
||||
A class for running inference on a given model.
|
||||
|
||||
Attributes:
|
||||
model_id (str): The ID of the model.
|
||||
device (str): The device to run the model on (either 'cuda' or 'cpu').
|
||||
max_length (int): The maximum length of the output sequence.
|
||||
quantize (bool, optional): Whether to use quantization. Defaults to False.
|
||||
quantization_config (dict, optional): The configuration for quantization.
|
||||
verbose (bool, optional): Whether to print verbose logs. Defaults to False.
|
||||
logger (logging.Logger, optional): The logger to use. Defaults to a basic logger.
|
||||
|
||||
# Usage
|
||||
```
|
||||
model_id = "NousResearch/Yarn-Mistral-7b-128k"
inference = YarnMistral128(model_id=model_id)
|
||||
|
||||
prompt_text = "Once upon a time"
|
||||
generated_text = inference(prompt_text)
|
||||
print(generated_text)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "NousResearch/Yarn-Mistral-7b-128k",
|
||||
device: str = None,
|
||||
max_length: int = 500,
|
||||
quantize: bool = False,
|
||||
quantization_config: dict = None,
|
||||
verbose=False,
|
||||
# logger=None,
|
||||
distributed=False,
|
||||
decoding=False,
|
||||
):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.device = (
|
||||
device
|
||||
if device
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
self.model_id = model_id
|
||||
self.max_length = max_length
|
||||
self.verbose = verbose
|
||||
self.distributed = distributed
|
||||
self.decoding = decoding
|
||||
self.quantization_config = quantization_config
self.model, self.tokenizer = None, None
|
||||
# self.log = Logging()
|
||||
|
||||
if self.distributed:
|
||||
assert (
|
||||
torch.cuda.device_count() > 1
|
||||
), "You need more than 1 gpu for distributed processing"
|
||||
|
||||
bnb_config = None
|
||||
if quantize:
|
||||
if not quantization_config:
|
||||
quantization_config = {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_use_double_quant": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
}
|
||||
bnb_config = BitsAndBytesConfig(**quantization_config)
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id,
|
||||
quantization_config=bnb_config,
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
self.model # .to(self.device)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to load the model or the tokenizer: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def load_model(self):
|
||||
"""Load the model"""
|
||||
if not self.model or not self.tokenizer:
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_id
|
||||
)
|
||||
|
||||
bnb_config = (
|
||||
BitsAndBytesConfig(**self.quantization_config)
|
||||
if self.quantization_config
|
||||
else None
|
||||
)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, quantization_config=bnb_config
|
||||
).to(self.device)
|
||||
|
||||
if self.distributed:
|
||||
self.model = DDP(self.model)
|
||||
except Exception as error:
|
||||
self.logger.error(
|
||||
"Failed to load the model or the tokenizer:"
|
||||
f" {error}"
|
||||
)
|
||||
raise
|
||||
|
||||
def run(self, prompt_text: str):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- prompt_text (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
self.load_model()
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
inputs = self.tokenizer.encode(
|
||||
prompt_text, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# self.log.start()
|
||||
|
||||
if self.decoding:
|
||||
with torch.no_grad():
|
||||
for _ in range(max_length):
|
||||
output_sequence = []
|
||||
|
||||
outputs = self.model.generate(
|
||||
inputs,
|
||||
max_length=len(inputs) + 1,
|
||||
do_sample=True,
|
||||
)
|
||||
output_tokens = outputs[0][-1]
|
||||
output_sequence.append(output_tokens.item())
|
||||
|
||||
# print token in real-time
|
||||
print(
|
||||
self.tokenizer.decode(
|
||||
[output_tokens],
|
||||
skip_special_tokens=True,
|
||||
),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
inputs = outputs
|
||||
else:
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
inputs, max_length=max_length, do_sample=True
|
||||
)
|
||||
|
||||
del inputs
|
||||
return self.tokenizer.decode(
|
||||
outputs[0], skip_special_tokens=True
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
||||
|
||||
async def run_async(self, task: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
Run the model asynchronously
|
||||
|
||||
Args:
|
||||
task (str): Task to run.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Examples:
|
||||
>>> model = YarnMistral128()
>>> text = await model.run_async("Once upon a time")
|
||||
|
||||
"""
|
||||
# Wrapping synchronous calls with async
|
||||
return self.run(task, *args, **kwargs)
|
||||
|
||||
def __call__(self, prompt_text: str):
|
||||
"""
|
||||
Generate a response based on the prompt text.
|
||||
|
||||
Args:
|
||||
- prompt_text (str): Text to prompt the model.
|
||||
- max_length (int): Maximum length of the response.
|
||||
|
||||
Returns:
|
||||
- Generated text (str).
|
||||
"""
|
||||
self.load_model()
|
||||
|
||||
max_length = self.max_length
|
||||
|
||||
try:
|
||||
inputs = self.tokenizer.encode(
|
||||
prompt_text, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# self.log.start()
|
||||
|
||||
if self.decoding:
|
||||
with torch.no_grad():
|
||||
for _ in range(max_length):
|
||||
output_sequence = []
|
||||
|
||||
outputs = self.model.generate(
|
||||
inputs,
|
||||
max_length=len(inputs) + 1,
|
||||
do_sample=True,
|
||||
)
|
||||
output_tokens = outputs[0][-1]
|
||||
output_sequence.append(output_tokens.item())
|
||||
|
||||
# print token in real-time
|
||||
print(
|
||||
self.tokenizer.decode(
|
||||
[output_tokens],
|
||||
skip_special_tokens=True,
|
||||
),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
inputs = outputs
|
||||
else:
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
inputs, max_length=max_length, do_sample=True
|
||||
)
|
||||
|
||||
del inputs
|
||||
|
||||
return self.tokenizer.decode(
|
||||
outputs[0], skip_special_tokens=True
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate the text: {e}")
|
||||
raise
|
||||
|
||||
async def __call_async__(self, task: str, *args, **kwargs) -> str:
|
||||
"""Call the model asynchronously""" ""
|
||||
return await self.run_async(task, *args, **kwargs)
|
||||
|
||||
def save_model(self, path: str):
|
||||
"""Save the model to a given path"""
|
||||
self.model.save_pretrained(path)
|
||||
self.tokenizer.save_pretrained(path)
|
||||
|
||||
def gpu_available(self) -> bool:
|
||||
"""Check if GPU is available"""
|
||||
return torch.cuda.is_available()
|
||||
|
||||
def memory_consumption(self) -> dict:
|
||||
"""Get the memory consumption of the GPU"""
|
||||
if self.gpu_available():
|
||||
torch.cuda.synchronize()
|
||||
allocated = torch.cuda.memory_allocated()
|
||||
reserved = torch.cuda.memory_reserved()
|
||||
return {"allocated": allocated, "reserved": reserved}
|
||||
else:
|
||||
return {"error": "GPU not available"}
|
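A minimal usage sketch for the methods above. The class name, constructor arguments, and model id are assumptions taken from the Hugging Face test file later in this commit, not guarantees about the deleted module:

from swarms.models.huggingface import HuggingfaceLLM  # assumed import path

# Assumed constructor signature (model_id, max_length), matching the tests below.
llm = HuggingfaceLLM(model_id="gpt2", max_length=128)
print(llm.run("Write one sentence about swarm intelligence."))
print(llm.memory_consumption())  # {"allocated": ..., "reserved": ...} or {"error": ...}
llm.save_model("./checkpoints/local-copy")  # illustrative path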
@ -1,97 +0,0 @@
from transformers import AutoModelForCausalLM, AutoTokenizer


class Yi34B200k:
    """
    A class for easy interaction with Yi34B200k.

    Attributes:
    -----------
    model_id: str
        The model id of the model to be used.
    device_map: str
        The device to be used for inference.
    torch_dtype: str
        The torch dtype to be used for inference.
    max_length: int
        The maximum length of the generated text.
    repitition_penalty: float
        The repetition penalty to be used for inference.
    no_repeat_ngram_size: int
        The no-repeat n-gram size to be used for inference.
    temperature: float
        The temperature to be used for inference.

    Methods:
    --------
    __call__(self, task: str) -> str:
        Generates text based on the given prompt.
    """

    def __init__(
        self,
        model_id: str = "01-ai/Yi-34B-200K",
        device_map: str = "auto",
        torch_dtype: str = "auto",
        max_length: int = 512,
        repitition_penalty: float = 1.3,
        no_repeat_ngram_size: int = 5,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.8,
    ):
        super().__init__()
        self.model_id = model_id
        self.device_map = device_map
        self.torch_dtype = torch_dtype
        self.max_length = max_length
        self.repitition_penalty = repitition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

    def __call__(self, task: str):
        """
        Generates text based on the given prompt.

        Args:
            task (str): The input text prompt.

        Returns:
            str: The generated text.
        """
        inputs = self.tokenizer(task, return_tensors="pt")
        outputs = self.model.generate(
            inputs.input_ids.cuda(),
            max_length=self.max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=self.repitition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


# # Example usage
# yi34b = Yi34B200k()
# prompt = "There's a place where time stands still. A place of breathtaking wonder, but also"
# generated_text = yi34b(prompt)
# print(generated_text)
@ -1,223 +0,0 @@
from unittest.mock import patch

# Import necessary modules
import pytest
import torch
from transformers import BioGptForCausalLM, BioGptTokenizer


# Fixture for BioGPT instance
@pytest.fixture
def biogpt_instance():
    from swarms.models import BioGPT

    return BioGPT()


# 36. Test if BioGPT provides a response for a simple biomedical question
def test_biomedical_response_1(biogpt_instance):
    question = "What are the functions of the mitochondria?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 37. Test for a genetics-based question
def test_genetics_response(biogpt_instance):
    question = "Can you explain the Mendelian inheritance?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 38. Test for a question about viruses
def test_virus_response(biogpt_instance):
    question = "How do RNA viruses replicate?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 39. Test for a cell biology related question
def test_cell_biology_response(biogpt_instance):
    question = "Describe the cell cycle and its phases."
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 40. Test for a question about protein structure
def test_protein_structure_response(biogpt_instance):
    question = (
        "What's the difference between alpha helix and beta sheet"
        " structures in proteins?"
    )
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 41. Test for a pharmacology question
def test_pharmacology_response(biogpt_instance):
    question = "How do beta blockers work?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 42. Test for an anatomy-based question
def test_anatomy_response(biogpt_instance):
    question = "Describe the structure of the human heart."
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 43. Test for a question about bioinformatics
def test_bioinformatics_response(biogpt_instance):
    question = "What is a BLAST search?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 44. Test for a neuroscience question
def test_neuroscience_response(biogpt_instance):
    question = "Explain the function of synapses in the nervous system."
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


# 45. Test for an immunology question
def test_immunology_response(biogpt_instance):
    question = "What is the role of T cells in the immune response?"
    response = biogpt_instance(question)
    assert response
    assert isinstance(response, str)


def test_init(bio_gpt):
    assert bio_gpt.model_name == "microsoft/biogpt"
    assert bio_gpt.max_length == 500
    assert bio_gpt.num_return_sequences == 5
    assert bio_gpt.do_sample is True
    assert bio_gpt.min_length == 100


def test_call(bio_gpt, monkeypatch):
    def mock_pipeline(*args, **kwargs):
        class MockGenerator:
            def __call__(self, text, **kwargs):
                return ["Generated text"]

        return MockGenerator()

    monkeypatch.setattr("transformers.pipeline", mock_pipeline)
    result = bio_gpt("Input text")
    assert result == ["Generated text"]


def test_get_features(bio_gpt):
    features = bio_gpt.get_features("Input text")
    assert "last_hidden_state" in features


def test_beam_search_decoding(bio_gpt):
    generated_text = bio_gpt.beam_search_decoding("Input text")
    assert isinstance(generated_text, str)


def test_set_pretrained_model(bio_gpt):
    bio_gpt.set_pretrained_model("new_model")
    assert bio_gpt.model_name == "new_model"


def test_get_config(bio_gpt):
    config = bio_gpt.get_config()
    assert "vocab_size" in config


def test_save_load_model(tmp_path, bio_gpt):
    bio_gpt.save_model(tmp_path)
    bio_gpt.load_from_path(tmp_path)
    assert bio_gpt.model_name == "microsoft/biogpt"


def test_print_model(capsys, bio_gpt):
    bio_gpt.print_model()
    captured = capsys.readouterr()
    assert "BioGptForCausalLM" in captured.out


# 26. Test if set_pretrained_model changes the model_name
def test_set_pretrained_model_name_change(biogpt_instance):
    biogpt_instance.set_pretrained_model("new_model_name")
    assert biogpt_instance.model_name == "new_model_name"


# 27. Test get_config return type
def test_get_config_return_type(biogpt_instance):
    config = biogpt_instance.get_config()
    assert isinstance(config, type(biogpt_instance.model.config))


# 28. Test saving model functionality by checking if files are created
@patch.object(BioGptForCausalLM, "save_pretrained")
@patch.object(BioGptTokenizer, "save_pretrained")
def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
    path = "test_path"
    biogpt_instance.save_model(path)
    mock_save_model.assert_called_once_with(path)
    mock_save_tokenizer.assert_called_once_with(path)


# 29. Test loading model from path
@patch.object(BioGptForCausalLM, "from_pretrained")
@patch.object(BioGptTokenizer, "from_pretrained")
def test_load_from_path(
    mock_load_model, mock_load_tokenizer, biogpt_instance
):
    path = "test_path"
    biogpt_instance.load_from_path(path)
    mock_load_model.assert_called_once_with(path)
    mock_load_tokenizer.assert_called_once_with(path)


# 30. Test print_model doesn't raise any error
def test_print_model_metadata(biogpt_instance):
    try:
        biogpt_instance.print_model()
    except Exception as e:
        pytest.fail(f"print_model() raised an exception: {e}")


# 31. Test that beam_search_decoding uses the correct number of beams
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
    biogpt_instance.beam_search_decoding("test_sentence", num_beams=7)
    _, kwargs = mock_generate.call_args
    assert kwargs["num_beams"] == 7


# 32. Test if beam_search_decoding handles early_stopping
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_early_stopping(
    mock_generate, biogpt_instance
):
    biogpt_instance.beam_search_decoding(
        "test_sentence", early_stopping=False
    )
    _, kwargs = mock_generate.call_args
    assert kwargs["early_stopping"] is False


# 33. Test get_features return type
def test_get_features_return_type(biogpt_instance):
    result = biogpt_instance.get_features("This is a sample text.")
    assert isinstance(result, torch.nn.modules.module.Module)


# 34. Test if default model is set correctly during initialization
def test_default_model_name(biogpt_instance):
    assert biogpt_instance.model_name == "microsoft/biogpt"
@ -1,96 +0,0 @@
import os
from unittest.mock import mock_open, patch

import pytest
from dotenv import load_dotenv

from swarms.models.eleven_labs import (
    ElevenLabsModel,
    ElevenLabsText2SpeechTool,
)

load_dotenv()

# Define some test data
SAMPLE_TEXT = "Hello, this is a test."
API_KEY = os.environ.get("ELEVEN_API_KEY")
EXPECTED_SPEECH_FILE = "expected_speech.wav"


@pytest.fixture
def eleven_labs_tool():
    return ElevenLabsText2SpeechTool()


# Basic functionality tests
def test_run_text_to_speech(eleven_labs_tool):
    speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
    assert isinstance(speech_file, str)
    assert speech_file.endswith(".wav")


def test_play_speech(eleven_labs_tool):
    with patch("builtins.open", mock_open(read_data="fake_audio_data")):
        eleven_labs_tool.play(EXPECTED_SPEECH_FILE)


def test_stream_speech(eleven_labs_tool):
    with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file:
        eleven_labs_tool.stream_speech(SAMPLE_TEXT)
        mock_file.assert_called_with(
            mode="bx", suffix=".wav", delete=False
        )


# Testing fixture and environment variables
def test_api_key_validation(eleven_labs_tool):
    with patch(
        "langchain.utils.get_from_dict_or_env", return_value=API_KEY
    ):
        values = {"eleven_api_key": None}
        validated_values = eleven_labs_tool.validate_environment(values)
        assert "eleven_api_key" in validated_values


# Mocking the external library
def test_run_text_to_speech_with_mock(eleven_labs_tool):
    with patch(
        "tempfile.NamedTemporaryFile", mock_open()
    ) as mock_file, patch(
        "your_module._import_elevenlabs"
    ) as mock_elevenlabs:
        mock_elevenlabs_instance = mock_elevenlabs.return_value
        mock_elevenlabs_instance.generate.return_value = b"fake_audio_data"
        eleven_labs_tool.run(SAMPLE_TEXT)
        assert mock_file.call_args[1]["suffix"] == ".wav"
        assert mock_file.call_args[1]["delete"] is False
        assert mock_file().write.call_args[0][0] == b"fake_audio_data"


# Exception testing
def test_run_text_to_speech_error_handling(eleven_labs_tool):
    with patch("your_module._import_elevenlabs") as mock_elevenlabs:
        mock_elevenlabs_instance = mock_elevenlabs.return_value
        mock_elevenlabs_instance.generate.side_effect = Exception(
            "Test Exception"
        )
        with pytest.raises(
            RuntimeError,
            match=(
                "Error while running ElevenLabsText2SpeechTool: Test"
                " Exception"
            ),
        ):
            eleven_labs_tool.run(SAMPLE_TEXT)


# Parameterized testing
@pytest.mark.parametrize(
    "model",
    [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL],
)
def test_run_text_to_speech_with_different_models(eleven_labs_tool, model):
    eleven_labs_tool.model = model
    speech_file = eleven_labs_tool.run(SAMPLE_TEXT)
    assert isinstance(speech_file, str)
    assert speech_file.endswith(".wav")
@ -1,177 +0,0 @@
import pytest
import requests

from swarms.models.gigabind import Gigabind

try:
    import requests_mock
except ImportError:
    requests_mock = None


@pytest.fixture
def api():
    return Gigabind(host="localhost", port=8000, endpoint="embeddings")


@pytest.fixture
def mock(requests_mock):
    requests_mock.post(
        "http://localhost:8000/embeddings", json={"result": "success"}
    )
    return requests_mock


def test_run_with_text(api, mock):
    response = api.run(text="Hello, world!")
    assert response == {"result": "success"}


def test_run_with_vision(api, mock):
    response = api.run(vision="image.jpg")
    assert response == {"result": "success"}


def test_run_with_audio(api, mock):
    response = api.run(audio="audio.mp3")
    assert response == {"result": "success"}


def test_run_with_all(api, mock):
    response = api.run(
        text="Hello, world!", vision="image.jpg", audio="audio.mp3"
    )
    assert response == {"result": "success"}


def test_run_with_none(api):
    with pytest.raises(ValueError):
        api.run()


def test_generate_summary(api, mock):
    response = api.generate_summary(text="Hello, world!")
    assert response == {"result": "success"}


def test_generate_summary_with_none(api):
    with pytest.raises(ValueError):
        api.generate_summary()


def test_retry_on_failure(api, requests_mock):
    requests_mock.post(
        "http://localhost:8000/embeddings",
        [
            {"status_code": 500, "json": {}},
            {"status_code": 500, "json": {}},
            {"status_code": 200, "json": {"result": "success"}},
        ],
    )
    response = api.run(text="Hello, world!")
    assert response == {"result": "success"}


def test_retry_exhausted(api, requests_mock):
    requests_mock.post(
        "http://localhost:8000/embeddings",
        [
            {"status_code": 500, "json": {}},
            {"status_code": 500, "json": {}},
            {"status_code": 500, "json": {}},
        ],
    )
    response = api.run(text="Hello, world!")
    assert response is None


def test_proxy_url(api):
    api.proxy_url = "http://proxy:8080"
    assert api.url == "http://proxy:8080"


def test_invalid_response(api, requests_mock):
    requests_mock.post("http://localhost:8000/embeddings", text="not json")
    response = api.run(text="Hello, world!")
    assert response is None


def test_connection_error(api, requests_mock):
    requests_mock.post(
        "http://localhost:8000/embeddings",
        exc=requests.exceptions.ConnectTimeout,
    )
    response = api.run(text="Hello, world!")
    assert response is None


def test_http_error(api, requests_mock):
    requests_mock.post("http://localhost:8000/embeddings", status_code=500)
    response = api.run(text="Hello, world!")
    assert response is None


def test_url_construction(api):
    assert api.url == "http://localhost:8000/embeddings"


def test_url_construction_with_proxy(api):
    api.proxy_url = "http://proxy:8080"
    assert api.url == "http://proxy:8080"


def test_run_with_large_text(api, mock):
    large_text = "Hello, world! " * 10000  # 10,000 repetitions
    response = api.run(text=large_text)
    assert response == {"result": "success"}


def test_run_with_large_vision(api, mock):
    large_vision = "image.jpg" * 10000  # 10,000 repetitions
    response = api.run(vision=large_vision)
    assert response == {"result": "success"}


def test_run_with_large_audio(api, mock):
    large_audio = "audio.mp3" * 10000  # 10,000 repetitions
    response = api.run(audio=large_audio)
    assert response == {"result": "success"}


def test_run_with_large_all(api, mock):
    large_text = "Hello, world! " * 10000  # 10,000 repetitions
    large_vision = "image.jpg" * 10000  # 10,000 repetitions
    large_audio = "audio.mp3" * 10000  # 10,000 repetitions
    response = api.run(
        text=large_text, vision=large_vision, audio=large_audio
    )
    assert response == {"result": "success"}


def test_run_with_timeout(api, mock):
    response = api.run(text="Hello, world!", timeout=0.001)
    assert response is None


def test_run_with_invalid_host(api):
    api.host = "invalid"
    response = api.run(text="Hello, world!")
    assert response is None


def test_run_with_invalid_port(api):
    api.port = 99999
    response = api.run(text="Hello, world!")
    assert response is None


def test_run_with_invalid_endpoint(api):
    api.endpoint = "invalid"
    response = api.run(text="Hello, world!")
    assert response is None


def test_run_with_invalid_proxy_url(api):
    api.proxy_url = "invalid"
    response = api.run(text="Hello, world!")
    assert response is None
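A minimal sketch of the request/retry contract these tests appear to assume: POST JSON to http://host:port/endpoint, return the decoded body on success, and return None once retries, timeouts, or connection errors are exhausted. The function name and retry count are illustrative, not the actual Gigabind implementation:

import requests


def post_with_retries(url: str, payload: dict, retries: int = 3, timeout: float = 10.0):
    # Try the request a few times; swallow transport and JSON-decoding errors.
    for _ in range(retries):
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            if resp.status_code == 200:
                return resp.json()  # e.g. {"result": "success"}
        except (requests.exceptions.RequestException, ValueError):
            pass
    return None  # retries exhausted or unrecoverable error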
@ -1,237 +0,0 @@
from unittest.mock import patch

import pytest
import torch

from swarms.models.huggingface import (
    HuggingfaceLLM,  # Replace with the actual import path
)


# Fixture for the class instance
@pytest.fixture
def llm_instance():
    model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha"
    instance = HuggingfaceLLM(model_id=model_id)
    return instance


# Test for instantiation and attributes
def test_llm_initialization(llm_instance):
    assert (
        llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha"
    )
    assert llm_instance.max_length == 500
    # ... add more assertions for all default attributes


# Parameterized test for setting devices
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_llm_set_device(llm_instance, device):
    llm_instance.set_device(device)
    assert llm_instance.device == device


# Test exception during initialization with a bad model_id
def test_llm_bad_model_initialization():
    with pytest.raises(Exception):
        HuggingfaceLLM(model_id="unknown-model")


# # Mocking the tokenizer and model to test run method
# @patch("swarms.models.huggingface.AutoTokenizer.from_pretrained")
# @patch(
#     "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
# )
# def test_llm_run(mock_model, mock_tokenizer, llm_instance):
#     mock_model.return_value.generate.return_value = "mocked output"
#     mock_tokenizer.return_value.encode.return_value = "mocked input"
#     result = llm_instance.run("test task")
#     assert result == "mocked output"


# Async test (requires pytest-asyncio plugin)
@pytest.mark.asyncio
async def test_llm_run_async(llm_instance):
    result = await llm_instance.run_async("test task")
    assert isinstance(result, str)


# Test for checking GPU availability
def test_llm_gpu_availability(llm_instance):
    # Assuming the test is running on a machine where the GPU availability is known
    expected_result = torch.cuda.is_available()
    assert llm_instance.gpu_available() == expected_result


# Test for memory consumption reporting
def test_llm_memory_consumption(llm_instance):
    # Mocking torch.cuda functions for consistent results
    with patch("torch.cuda.memory_allocated", return_value=1024):
        with patch("torch.cuda.memory_reserved", return_value=2048):
            memory = llm_instance.memory_consumption()
    assert memory == {"allocated": 1024, "reserved": 2048}


# Test different initialization parameters
@pytest.mark.parametrize(
    "model_id, max_length",
    [
        ("NousResearch/Nous-Hermes-2-Vision-Alpha", 100),
        ("microsoft/Orca-2-13b", 200),
        (
            "berkeley-nest/Starling-LM-7B-alpha",
            None,
        ),  # None to check default behavior
    ],
)
def test_llm_initialization_params(model_id, max_length):
    if max_length:
        instance = HuggingfaceLLM(model_id=model_id, max_length=max_length)
        assert instance.max_length == max_length
    else:
        instance = HuggingfaceLLM(model_id=model_id)
        assert (
            instance.max_length == 500
        )  # Assuming 500 is the default max_length


# Test for setting an invalid device
def test_llm_set_invalid_device(llm_instance):
    with pytest.raises(ValueError):
        llm_instance.set_device("quantum_processor")


# Mocking external API call to test run method without network
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_run_without_network(mock_run, llm_instance):
    mock_run.return_value = "mocked output"
    result = llm_instance.run("test task without network")
    assert result == "mocked output"


# Test handling of empty input for the run method
def test_llm_run_empty_input(llm_instance):
    with pytest.raises(ValueError):
        llm_instance.run("")


# Test the generation with a provided seed for reproducibility
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_run_with_seed(mock_run, llm_instance):
    seed = 42
    llm_instance.set_seed(seed)
    # Assuming set_seed method affects the randomness in the model
    # You would typically ensure that setting the seed gives reproducible results
    mock_run.return_value = "mocked deterministic output"
    result = llm_instance.run("test task", seed=seed)
    assert result == "mocked deterministic output"


# Test the output length is as expected
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_run_output_length(mock_run, llm_instance):
    input_text = "test task"
    llm_instance.max_length = 50  # set a max_length for the output
    mock_run.return_value = "mocked output" * 10  # some long text
    result = llm_instance.run(input_text)
    assert len(result.split()) <= llm_instance.max_length


# Test the tokenizer handling special tokens correctly
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
def test_llm_tokenizer_special_tokens(
    mock_decode, mock_encode, llm_instance
):
    mock_encode.return_value = "encoded input with special tokens"
    mock_decode.return_value = "decoded output with special tokens"
    result = llm_instance.run("test task with special tokens")
    mock_encode.assert_called_once()
    mock_decode.assert_called_once()
    assert "special tokens" in result


# Test for correct handling of timeouts
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_timeout_handling(mock_run, llm_instance):
    mock_run.side_effect = TimeoutError
    with pytest.raises(TimeoutError):
        llm_instance.run("test task with timeout")


# Test for response time within a threshold (performance test)
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_response_time(mock_run, llm_instance):
    import time

    mock_run.return_value = "mocked output"
    start_time = time.time()
    llm_instance.run("test task for response time")
    end_time = time.time()
    assert (
        end_time - start_time < 1
    )  # Assuming the response should be faster than 1 second


# Test the logging of a warning for long inputs
@patch("swarms.models.huggingface.logging.warning")
def test_llm_long_input_warning(mock_warning, llm_instance):
    long_input = "x" * 10000  # input longer than the typical limit
    llm_instance.run(long_input)
    mock_warning.assert_called_once()


# Test for run method behavior when model raises an exception
@patch(
    "swarms.models.huggingface.HuggingfaceLLM._model.generate",
    side_effect=RuntimeError,
)
def test_llm_run_model_exception(mock_generate, llm_instance):
    with pytest.raises(RuntimeError):
        llm_instance.run("test task when model fails")


# Test the behavior when GPU is forced but not available
@patch("torch.cuda.is_available", return_value=False)
def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance):
    with pytest.raises(EnvironmentError):
        llm_instance.set_device(
            "cuda"
        )  # Attempt to set CUDA when it's not available


# Test for proper cleanup after model use (releasing resources)
@patch("swarms.models.huggingface.HuggingfaceLLM._model")
def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
    llm_instance.cleanup()
    # Assuming cleanup method is meant to free resources
    mock_model.delete.assert_called_once()
    mock_tokenizer.delete.assert_called_once()


# Test model's ability to handle multilingual input
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_multilingual_input(mock_run, llm_instance):
    mock_run.return_value = "mocked multilingual output"
    multilingual_input = "Bonjour, ceci est un test multilingue."
    result = llm_instance.run(multilingual_input)
    assert isinstance(
        result, str
    )  # Simple check to ensure output is string type


# Test caching mechanism to prevent re-running the same inputs
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_caching_mechanism(mock_run, llm_instance):
    input_text = "test caching mechanism"
    mock_run.return_value = "cached output"
    # Run the input twice
    first_run_result = llm_instance.run(input_text)
    second_run_result = llm_instance.run(input_text)
    mock_run.assert_called_once()  # Should only be called once due to caching
    assert first_run_result == second_run_result


# These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class.
# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors that you need to replace with actual implementations.
@ -1,84 +0,0 @@
import pytest
import torch

from swarms.models.jina_embeds import JinaEmbeddings


@pytest.fixture
def model():
    return JinaEmbeddings("bert-base-uncased", verbose=True)


def test_initialization(model):
    assert isinstance(model, JinaEmbeddings)
    assert model.device in ["cuda", "cpu"]
    assert model.max_length == 500
    assert model.verbose is True


def test_run_sync(model):
    task = "Encode this text"
    result = model.run(task)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (model.max_length,)


def test_run_async(model):
    task = "Encode this text"
    result = model.run_async(task)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (model.max_length,)


def test_save_model(tmp_path, model):
    model_path = tmp_path / "model"
    model.save_model(model_path)
    assert (model_path / "config.json").is_file()
    assert (model_path / "pytorch_model.bin").is_file()
    assert (model_path / "vocab.txt").is_file()


def test_gpu_available(model):
    gpu_status = model.gpu_available()
    if torch.cuda.is_available():
        assert gpu_status is True
    else:
        assert gpu_status is False


def test_memory_consumption(model):
    memory_stats = model.memory_consumption()
    if torch.cuda.is_available():
        assert "allocated" in memory_stats
        assert "reserved" in memory_stats
    else:
        assert "error" in memory_stats


def test_cosine_similarity(model):
    task1 = "This is a sample text for testing."
    task2 = "Another sample text for testing."
    embeddings1 = model.run(task1)
    embeddings2 = model.run(task2)
    sim = model.cos_sim(embeddings1, embeddings2)
    assert isinstance(sim, torch.Tensor)
    assert sim.item() >= -1.0
    assert sim.item() <= 1.0


def test_failed_load_model(caplog):
    with pytest.raises(Exception):
        JinaEmbeddings("invalid_model")
    assert "Failed to load the model or the tokenizer" in caplog.text


def test_failed_generate_text(caplog, model):
    with pytest.raises(Exception):
        model.run("invalid_task")
    assert "Failed to generate the text" in caplog.text


@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_change_device(model, device):
    model.device = device
    assert model.device == device
@ -1,43 +0,0 @@
from unittest.mock import patch

import torch

from swarms.models import TimmModel


def test_timm_model_init():
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        mock_list_models.assert_called_once()
        assert timm_model.model_name == model_name
        assert timm_model.pretrained == pretrained
        assert timm_model.in_chans == in_chans
        assert timm_model.models == mock_list_models.return_value


def test_timm_model_call():
    with patch("swarms.models.timm.create_model") as mock_create_model:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        task = torch.rand(1, in_chans, 224, 224)
        result = timm_model(task)
        mock_create_model.assert_called_once_with(
            model_name, pretrained=pretrained, in_chans=in_chans
        )
        assert result == mock_create_model.return_value(task)


def test_timm_model_list_models():
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        result = timm_model.list_models()
        mock_list_models.assert_called_once()
        assert result == mock_list_models.return_value
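For context, a sketch of the wrapper behaviour these tests mock, written against the public timm API (timm.list_models and timm.create_model). It is an illustration under those assumptions, not the actual swarms TimmModel:

import timm
import torch


class TimmModelSketch:
    """Hypothetical stand-in mirroring what the tests above exercise."""

    def __init__(self, model_name: str, pretrained: bool, in_chans: int):
        self.model_name = model_name
        self.pretrained = pretrained
        self.in_chans = in_chans
        self.models = timm.list_models()  # cached list of available architectures

    def __call__(self, task: torch.Tensor) -> torch.Tensor:
        # Build the backbone on demand and run a forward pass on the input batch.
        model = timm.create_model(
            self.model_name, pretrained=self.pretrained, in_chans=self.in_chans
        )
        return model(task)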
@ -1,35 +0,0 @@
from unittest.mock import patch

from swarms.models.ultralytics_model import UltralyticsModel


def test_ultralytics_init():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)
        mock_yolo.assert_called_once_with(model_name)
        assert ultralytics.model_name == model_name
        assert ultralytics.model == mock_yolo.return_value


def test_ultralytics_call():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)
        task = "detect"
        args = (1, 2, 3)
        kwargs = {"a": "A", "b": "B"}
        result = ultralytics(task, *args, **kwargs)
        mock_yolo.return_value.assert_called_once_with(
            task, *args, **kwargs
        )
        assert result == mock_yolo.return_value.return_value


def test_ultralytics_list_models():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)
        result = ultralytics.list_models()
        mock_yolo.list_models.assert_called_once()
        assert result == mock_yolo.list_models.return_value
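Similarly, a hypothetical sketch of the call pattern mocked here, assuming the ultralytics YOLO interface; it is not the actual UltralyticsModel implementation:

from ultralytics import YOLO


class UltralyticsModelSketch:
    """Thin wrapper: construct a YOLO model once, then forward calls to it."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = YOLO(model_name)

    def __call__(self, task, *args, **kwargs):
        # Delegate straight to the underlying YOLO callable (prediction).
        return self.model(task, *args, **kwargs)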
@ -1,126 +0,0 @@
import pytest
import torch
from transformers import AutoTokenizer

from swarms.models.yi_200k import Yi34B200k


# Create fixtures if needed
@pytest.fixture
def yi34b_model():
    return Yi34B200k()


# Test cases for the Yi34B200k class
def test_yi34b_init(yi34b_model):
    assert isinstance(yi34b_model.model, torch.nn.Module)
    assert isinstance(yi34b_model.tokenizer, AutoTokenizer)


def test_yi34b_generate_text(yi34b_model):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt)
    assert isinstance(generated_text, str)
    assert len(generated_text) > 0


@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, max_length=max_length)
    assert len(generated_text) <= max_length


@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, temperature=temperature)
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
    prompt = None  # Invalid prompt
    with pytest.raises(
        ValueError, match="Input prompt must be a non-empty string"
    ):
        yi34b_model(prompt)


def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
    prompt = "There's a place where time stands still."
    max_length = -1  # Invalid max_length
    with pytest.raises(
        ValueError, match="max_length must be a positive integer"
    ):
        yi34b_model(prompt, max_length=max_length)


def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
    prompt = "There's a place where time stands still."
    temperature = 2.0  # Invalid temperature
    with pytest.raises(
        ValueError, match="temperature must be between 0.01 and 1.0"
    ):
        yi34b_model(prompt, temperature=temperature)


@pytest.mark.parametrize("top_k", [20, 30, 50])
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, top_k=top_k)
    assert isinstance(generated_text, str)


@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, top_p=top_p)
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
    prompt = "There's a place where time stands still."
    top_k = -1  # Invalid top_k
    with pytest.raises(
        ValueError, match="top_k must be a non-negative integer"
    ):
        yi34b_model(prompt, top_k=top_k)


def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
    prompt = "There's a place where time stands still."
    top_p = 1.5  # Invalid top_p
    with pytest.raises(
        ValueError, match="top_p must be between 0.0 and 1.0"
    ):
        yi34b_model(prompt, top_p=top_p)


@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(
    yi34b_model, repitition_penalty
):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(
        prompt, repitition_penalty=repitition_penalty
    )
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_repitition_penalty(
    yi34b_model,
):
    prompt = "There's a place where time stands still."
    repitition_penalty = 0.0  # Invalid repitition_penalty
    with pytest.raises(
        ValueError,
        match="repitition_penalty must be a positive float",
    ):
        yi34b_model(prompt, repitition_penalty=repitition_penalty)


def test_yi34b_generate_text_with_invalid_device(yi34b_model):
    prompt = "There's a place where time stands still."
    device_map = "invalid_device"  # Invalid device_map
    with pytest.raises(ValueError, match="Invalid device_map"):
        yi34b_model(prompt, device_map=device_map)