You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
5.0 KiB
170 lines
5.0 KiB
# Import necessary modules and define fixtures if needed
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
from swarms.models.bioclip import BioClip
|
|
|
|
|
|
# Define fixtures if needed
|
|
@pytest.fixture
def sample_image_path(tmp_path):
    """Return the path (as a ``str``) of a small JPEG created for this test.

    The previous implementation returned the placeholder string
    ``"path_to_sample_image.jpg"``. Unless a file with that name happened to
    exist in the working directory, every test that opened the path would
    fail before exercising the model. Writing a real image into pytest's
    per-test temporary directory makes the fixture self-contained.

    Args:
        tmp_path: pytest's built-in per-test temporary directory fixture.

    Returns:
        The filesystem path of the generated image, as a string so callers
        keep receiving the same type as before.
    """
    image_file = tmp_path / "sample_image.jpg"
    # Plain grey RGB image; the content is irrelevant, it only has to be a
    # decodable file. NOTE(review): 224x224 is chosen to mirror the
    # "patch16_224" suffix of the model name - confirm it suits the model.
    Image.new("RGB", (224, 224), color=(127, 127, 127)).save(image_file)
    return str(image_file)
|
|
|
|
|
|
@pytest.fixture
def clip_instance():
    """Provide a BioClip object built from the BiomedCLIP model identifier.

    The identifier string is passed to ``BioClip`` as its only positional
    argument.
    """
    model_identifier = (
        "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
    )
    return BioClip(model_identifier)
|
|
|
|
|
|
# Basic tests for the BioClip class
|
|
def test_clip_initialization(clip_instance):
    """Check the freshly built instance wraps a torch module and exposes
    the attributes the other tests rely on."""
    assert isinstance(clip_instance.model, torch.nn.Module)

    # Checked in the same order as before so the first missing attribute
    # is the one that fails the test.
    required_attributes = (
        "model_path",
        "preprocess_train",
        "preprocess_val",
        "tokenizer",
        "device",
    )
    for attribute_name in required_attributes:
        assert hasattr(clip_instance, attribute_name)
|
|
|
|
|
|
def test_clip_call_method(clip_instance, sample_image_path):
    """Call the instance with an image path and candidate labels, then
    assert the result is a dict holding one entry per label."""
    candidate_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]
    expected_entries = len(candidate_labels)

    prediction = clip_instance(sample_image_path, candidate_labels)

    assert isinstance(prediction, dict)
    assert len(prediction) == expected_entries
|
|
|
|
|
|
def test_clip_plot_image_with_metadata(clip_instance, sample_image_path):
    """Smoke test: ``plot_image_with_metadata`` must run without raising.

    No return value is inspected; the test only passes the image path and a
    metadata dict with ``filename`` and ``top_probs`` keys.
    """
    top_probs = {"label1": 0.75, "label2": 0.65}
    metadata = {"filename": "sample_image.jpg", "top_probs": top_probs}
    clip_instance.plot_image_with_metadata(sample_image_path, metadata)
|
|
|
|
|
|
# More test cases can be added to cover additional functionality and edge cases
|
|
|
|
|
|
# Parameterized tests for different image and label combinations
|
|
@pytest.mark.parametrize(
    "image_path, labels",
    [
        ("image1.jpg", ["label1", "label2"]),
        ("image2.jpg", ["label3", "label4"]),
        # Add more image and label combinations
    ],
)
def test_clip_parameterized_calls(clip_instance, image_path, labels):
    """Run the call check once per (image path, labels) pair and assert a
    dict with one entry per label comes back."""
    expected_entries = len(labels)
    prediction = clip_instance(image_path, labels)
    assert isinstance(prediction, dict)
    assert len(prediction) == expected_entries
|
|
|
|
|
|
# Test image preprocessing
|
|
def test_clip_image_preprocessing(clip_instance, sample_image_path):
    """Check that ``preprocess_val`` turns an opened PIL image into a tensor.

    The image is opened in a ``with`` block so the underlying file handle is
    released even when preprocessing raises; the original left it open.
    """
    with Image.open(sample_image_path) as image:
        # Preprocess inside the block: PIL loads pixel data lazily, so the
        # file must still be open when the transform reads it.
        processed_image = clip_instance.preprocess_val(image)
    assert isinstance(processed_image, torch.Tensor)
|
|
|
|
|
|
# Test label tokenization
|
|
def test_clip_label_tokenization(clip_instance):
    """Tokenize two labels and assert the result is a tensor whose first
    dimension equals the number of labels."""
    label_batch = ["label1", "label2"]
    token_tensor = clip_instance.tokenizer(label_batch)
    assert isinstance(token_tensor, torch.Tensor)
    batch_rows = token_tensor.shape[0]
    assert batch_rows == len(label_batch)
|
|
|
|
|
|
# More tests can be added to cover other methods and edge cases
|
|
|
|
|
|
# End-to-end tests with actual images and labels
|
|
def test_clip_end_to_end(clip_instance, sample_image_path):
    """Classify the sample image against nine biomedical and chart labels
    and assert a dict with one entry per label is returned."""
    class_names = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    outcome = clip_instance(sample_image_path, class_names)

    assert isinstance(outcome, dict)
    assert len(outcome) == len(class_names)
|
|
|
|
|
|
# Test label tokenization with a large number of labels
|
|
def test_clip_long_labels(clip_instance):
    """Tokenize 100 generated labels (``label0`` .. ``label99``) and assert
    the result is a tensor with one row per label."""
    label_count = 100
    generated_labels = [f"label{index}" for index in range(label_count)]
    token_tensor = clip_instance.tokenizer(generated_labels)
    assert isinstance(token_tensor, torch.Tensor)
    assert token_tensor.shape[0] == len(generated_labels)
|
|
|
|
|
|
# Test handling of multiple image files
|
|
def test_clip_multiple_images(clip_instance, sample_image_path):
    """Pass a list of two image paths and assert a list comes back holding
    one dict per image, each with one entry per label."""
    labels = ["label1", "label2"]
    image_paths = [sample_image_path, "image2.jpg"]
    expected_images = len(image_paths)
    expected_entries = len(labels)

    results = clip_instance(image_paths, labels)

    assert isinstance(results, list)
    assert len(results) == expected_images
    for per_image in results:
        assert isinstance(per_image, dict)
        assert len(per_image) == expected_entries
|
|
|
|
|
|
# Test model inference performance
|
|
def test_clip_inference_performance(clip_instance, sample_image_path, benchmark):
    """Run the instance call through the ``benchmark`` fixture and assert
    the value it hands back is a dict with one entry per label."""
    candidate_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    # benchmark(fn, *args) invokes fn(*args) and returns fn's result.
    timed_result = benchmark(clip_instance, sample_image_path, candidate_labels)

    assert isinstance(timed_result, dict)
    assert len(timed_result) == len(candidate_labels)
|
|
|
|
|
|
# Test different preprocessing pipelines
|
|
def test_clip_preprocessing_pipelines(clip_instance, sample_image_path):
    """Check that both the training and validation preprocessing callables
    produce a tensor from the same opened image.

    The image is opened in a ``with`` block so the file handle is closed
    once both pipelines have run; the original never closed it.
    """
    with Image.open(sample_image_path) as image:
        # Both transforms run while the file is open because PIL defers
        # reading pixel data until it is first needed.

        # Test preprocessing for training
        processed_image_train = clip_instance.preprocess_train(image)

        # Test preprocessing for validation
        processed_image_val = clip_instance.preprocess_val(image)

    assert isinstance(processed_image_train, torch.Tensor)
    assert isinstance(processed_image_val, torch.Tensor)
|
|
|
|
|
|
# ...
|