You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/models/test_bioclip.py

172 lines
5.0 KiB

# Import necessary modules and define fixtures if needed
import os
import pytest
import torch
from PIL import Image
from swarms.models.bioclip import BioClip
# Define fixtures if needed
@pytest.fixture
def sample_image_path(tmp_path):
    """Write a placeholder JPEG into pytest's per-test temp dir and return its path.

    A real file is created because the tests hand this path to ``Image.open``
    and to the model; a path to a non-existent file would make them fail
    before the code under test is exercised.

    Args:
        tmp_path: pytest's built-in temporary-directory fixture.

    Returns:
        The image location as a ``str``.
    """
    image_file = tmp_path / "sample_image.jpg"
    # Solid-colour RGB image: the tests only check result types and sizes,
    # so the pixel content is irrelevant.
    Image.new("RGB", (224, 224), color=(127, 127, 127)).save(image_file)
    return str(image_file)
@pytest.fixture
def clip_instance():
    """Build the BioClip object shared by the tests in this module."""
    model_name = "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
    return BioClip(model_name)
# Basic tests for the BioClip class
def test_clip_initialization(clip_instance):
    """Check that the instance holds a torch module and the expected attributes."""
    assert isinstance(clip_instance.model, torch.nn.Module)

    expected_attributes = (
        "model_path",
        "preprocess_train",
        "preprocess_val",
        "tokenizer",
        "device",
    )
    for attribute in expected_attributes:
        assert hasattr(clip_instance, attribute)
def test_clip_call_method(clip_instance, sample_image_path):
    """Call the instance on one image; expect a dict with one entry per label."""
    candidate_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    prediction = clip_instance(sample_image_path, candidate_labels)

    assert isinstance(prediction, dict)
    assert len(prediction) == len(candidate_labels)
def test_clip_plot_image_with_metadata(
    clip_instance, sample_image_path
):
    """Smoke test: plot_image_with_metadata must complete without raising."""
    top_probs = {"label1": 0.75, "label2": 0.65}
    metadata = {"filename": "sample_image.jpg", "top_probs": top_probs}
    clip_instance.plot_image_with_metadata(sample_image_path, metadata)
# More test cases can be added to cover additional functionality and edge cases
# Parameterized tests for different image and label combinations
@pytest.mark.parametrize(
    "image_path, labels",
    [
        ("image1.jpg", ["label1", "label2"]),
        ("image2.jpg", ["label3", "label4"]),
        # Add more image and label combinations
    ],
)
def test_clip_parameterized_calls(clip_instance, image_path, labels):
    """For each (image, labels) pair, expect a dict sized like the label list."""
    prediction = clip_instance(image_path, labels)
    expected_size = len(labels)

    assert isinstance(prediction, dict)
    assert len(prediction) == expected_size
# Test image preprocessing
def test_clip_image_preprocessing(clip_instance, sample_image_path):
    """Run the validation preprocessing on the sample image and expect a tensor."""
    # Image.open leaves the underlying file open; the context manager
    # guarantees it is closed even if preprocessing raises.
    with Image.open(sample_image_path) as image:
        processed_image = clip_instance.preprocess_val(image)

    assert isinstance(processed_image, torch.Tensor)
# Test label tokenization
def test_clip_label_tokenization(clip_instance):
    """Tokenize two labels; expect a tensor whose first dimension is two."""
    sample_labels = ["label1", "label2"]
    tokens = clip_instance.tokenizer(sample_labels)

    assert isinstance(tokens, torch.Tensor)
    row_count = tokens.shape[0]
    assert row_count == len(sample_labels)
# More tests can be added to cover other methods and edge cases
# End-to-end tests with actual images and labels
def test_clip_end_to_end(clip_instance, sample_image_path):
    """Classify the sample image against nine labels and check the result shape."""
    label_set = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    outcome = clip_instance(sample_image_path, label_set)

    assert isinstance(outcome, dict)
    assert len(outcome) == len(label_set)
# Test label tokenization with a large number of labels
def test_clip_long_labels(clip_instance):
    """Tokenize 100 generated labels; expect a tensor with one row per label."""
    generated_labels = [f"label{index}" for index in range(100)]
    tokens = clip_instance.tokenizer(generated_labels)

    assert isinstance(tokens, torch.Tensor)
    assert tokens.shape[0] == len(generated_labels)
# Test handling of multiple image files
def test_clip_multiple_images(clip_instance, sample_image_path):
    """Pass a list of two image paths; expect a list of per-image dicts."""
    labels = ["label1", "label2"]
    image_paths = [sample_image_path, "image2.jpg"]

    batch_output = clip_instance(image_paths, labels)

    assert isinstance(batch_output, list)
    assert len(batch_output) == len(image_paths)

    label_count = len(labels)
    for per_image in batch_output:
        assert isinstance(per_image, dict)
        assert len(per_image) == label_count
# Test model inference performance
def test_clip_inference_performance(
    clip_instance, sample_image_path, benchmark
):
    """Run the classification call through the ``benchmark`` fixture.

    The value returned by ``benchmark`` is checked to be a dict with one
    entry per label.
    """
    candidate_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    timed_result = benchmark(
        clip_instance, sample_image_path, candidate_labels
    )

    assert isinstance(timed_result, dict)
    assert len(timed_result) == len(candidate_labels)
# Test different preprocessing pipelines
def test_clip_preprocessing_pipelines(
    clip_instance, sample_image_path
):
    """Apply both preprocessing callables to the sample image; each must yield a tensor."""
    # Keep the file open only while it is being preprocessed; the context
    # manager closes the handle that Image.open would otherwise leave open.
    with Image.open(sample_image_path) as image:
        # Test preprocessing for training
        processed_image_train = clip_instance.preprocess_train(image)
        assert isinstance(processed_image_train, torch.Tensor)

        # Test preprocessing for validation
        processed_image_val = clip_instance.preprocess_val(image)
        assert isinstance(processed_image_val, torch.Tensor)
# ...