# Import necessary modules and define fixtures if needed import os import pytest import torch from PIL import Image from swarms.models.bioclip import BioClip # Define fixtures if needed @pytest.fixture def sample_image_path(): return "path_to_sample_image.jpg" @pytest.fixture def clip_instance(): return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224") # Basic tests for the BioClip class def test_clip_initialization(clip_instance): assert isinstance(clip_instance.model, torch.nn.Module) assert hasattr(clip_instance, "model_path") assert hasattr(clip_instance, "preprocess_train") assert hasattr(clip_instance, "preprocess_val") assert hasattr(clip_instance, "tokenizer") assert hasattr(clip_instance, "device") def test_clip_call_method(clip_instance, sample_image_path): labels = [ "adenocarcinoma histopathology", "brain MRI", "covid line chart", "squamous cell carcinoma histopathology", "immunohistochemistry histopathology", "bone X-ray", "chest X-ray", "pie chart", "hematoxylin and eosin histopathology", ] result = clip_instance(sample_image_path, labels) assert isinstance(result, dict) assert len(result) == len(labels) def test_clip_plot_image_with_metadata(clip_instance, sample_image_path): metadata = { "filename": "sample_image.jpg", "top_probs": {"label1": 0.75, "label2": 0.65}, } clip_instance.plot_image_with_metadata(sample_image_path, metadata) # More test cases can be added to cover additional functionality and edge cases # Parameterized tests for different image and label combinations @pytest.mark.parametrize( "image_path, labels", [ ("image1.jpg", ["label1", "label2"]), ("image2.jpg", ["label3", "label4"]), # Add more image and label combinations ], ) def test_clip_parameterized_calls(clip_instance, image_path, labels): result = clip_instance(image_path, labels) assert isinstance(result, dict) assert len(result) == len(labels) # Test image preprocessing def test_clip_image_preprocessing(clip_instance, sample_image_path): image = Image.open(sample_image_path) processed_image = clip_instance.preprocess_val(image) assert isinstance(processed_image, torch.Tensor) # Test label tokenization def test_clip_label_tokenization(clip_instance): labels = ["label1", "label2"] tokenized_labels = clip_instance.tokenizer(labels) assert isinstance(tokenized_labels, torch.Tensor) assert tokenized_labels.shape[0] == len(labels) # More tests can be added to cover other methods and edge cases # End-to-end tests with actual images and labels def test_clip_end_to_end(clip_instance, sample_image_path): labels = [ "adenocarcinoma histopathology", "brain MRI", "covid line chart", "squamous cell carcinoma histopathology", "immunohistochemistry histopathology", "bone X-ray", "chest X-ray", "pie chart", "hematoxylin and eosin histopathology", ] result = clip_instance(sample_image_path, labels) assert isinstance(result, dict) assert len(result) == len(labels) # Test label tokenization with long labels def test_clip_long_labels(clip_instance): labels = ["label" + str(i) for i in range(100)] tokenized_labels = clip_instance.tokenizer(labels) assert isinstance(tokenized_labels, torch.Tensor) assert tokenized_labels.shape[0] == len(labels) # Test handling of multiple image files def test_clip_multiple_images(clip_instance, sample_image_path): labels = ["label1", "label2"] image_paths = [sample_image_path, "image2.jpg"] results = clip_instance(image_paths, labels) assert isinstance(results, list) assert len(results) == len(image_paths) for result in results: assert isinstance(result, dict) assert len(result) == len(labels) # Test model inference performance def test_clip_inference_performance( clip_instance, sample_image_path, benchmark ): labels = [ "adenocarcinoma histopathology", "brain MRI", "covid line chart", "squamous cell carcinoma histopathology", "immunohistochemistry histopathology", "bone X-ray", "chest X-ray", "pie chart", "hematoxylin and eosin histopathology", ] result = benchmark(clip_instance, sample_image_path, labels) assert isinstance(result, dict) assert len(result) == len(labels) # Test different preprocessing pipelines def test_clip_preprocessing_pipelines(clip_instance, sample_image_path): labels = ["label1", "label2"] image = Image.open(sample_image_path) # Test preprocessing for training processed_image_train = clip_instance.preprocess_train(image) assert isinstance(processed_image_train, torch.Tensor) # Test preprocessing for validation processed_image_val = clip_instance.preprocess_val(image) assert isinstance(processed_image_val, torch.Tensor) # ...