You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_gemini.py

308 lines
8.8 KiB

6 months ago
from unittest.mock import Mock, patch
import pytest
from swarms.models.gemini import Gemini
# Define test fixtures
@pytest.fixture
def mock_gemini_api_key(monkeypatch):
    """Set the ``GEMINI_API_KEY`` environment variable to a dummy value.

    The variable is set through pytest's ``monkeypatch`` fixture, so the
    change is undone automatically when the requesting test finishes.
    """
    env_name, env_value = "GEMINI_API_KEY", "mocked-api-key"
    monkeypatch.setenv(env_name, env_value)
@pytest.fixture
def mock_genai_model():
    """Provide a fresh, unconfigured ``Mock`` for each requesting test.

    NOTE(review): this fixture only builds and returns the mock; it does
    not patch anything itself.
    """
    stand_in = Mock()
    return stand_in
# Test initialization of Gemini
def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
    """Check the attributes of a ``Gemini`` built with no arguments.

    Expects ``model_name`` to be "gemini-pro", ``gemini_api_key`` to be
    "mocked-api-key", and ``model`` to be the ``mock_genai_model`` object.
    """
    gemini = Gemini()
    expected_name, expected_key = "gemini-pro", "mocked-api-key"

    assert gemini.model_name == expected_name
    assert gemini.gemini_api_key == expected_key
    # NOTE(review): nothing in this test injects mock_genai_model into
    # Gemini -- confirm that this identity check can actually hold.
    assert gemini.model is mock_genai_model
def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model):
    """Check that explicit constructor arguments end up on the instance.

    Passes a custom ``model_name`` and ``gemini_api_key`` and expects both
    to be readable back from the attributes of the same name.
    """
    overrides = {
        "model_name": "custom-model",
        "gemini_api_key": "custom-api-key",
    }
    gemini = Gemini(**overrides)

    assert gemini.model_name == overrides["model_name"]
    assert gemini.gemini_api_key == overrides["gemini_api_key"]
    # NOTE(review): nothing in this test injects mock_genai_model into
    # Gemini -- confirm that this identity check can actually hold.
    assert gemini.model is mock_genai_model
# Test Gemini run method
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_with_img(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    """Run ``Gemini.run`` with a task and an image, both backends mocked.

    ``process_img`` and ``generate_content`` are patched. The test expects
    ``run`` to return the ``text`` of the mocked response, to call
    ``generate_content`` with ``content=[task, processed image]``, and to
    call ``process_img`` with ``img=<path>``.
    """
    gemini = Gemini()
    prompt, image_path = "A cat", "cat.png"
    processed, expected_text = "Processed image", "Generated response"

    # Patch decorators are applied bottom-up, so the first mock argument
    # belongs to generate_content and the second to process_img.
    mock_generate_content.return_value = Mock(text=expected_text)
    mock_process_img.return_value = processed

    result = gemini.run(task=prompt, img=image_path)

    assert result == expected_text
    mock_generate_content.assert_called_with(content=[prompt, processed])
    mock_process_img.assert_called_with(img=image_path)
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_without_img(
    mock_generate_content, mock_gemini_api_key, mock_genai_model
):
    """Run ``Gemini.run`` with only a task and a mocked ``generate_content``.

    Expects the ``text`` of the mocked response to be returned and
    ``generate_content`` to have been called with ``task=<task>``.
    """
    gemini = Gemini()
    prompt = "A cat"
    expected_text = "Generated response"
    mock_generate_content.return_value = Mock(text=expected_text)

    result = gemini.run(task=prompt)

    assert result == expected_text
    mock_generate_content.assert_called_with(task=prompt)
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_exception(
    mock_generate_content, mock_gemini_api_key, mock_genai_model
):
    """Expect ``run`` to return ``None`` when ``generate_content`` raises."""
    gemini = Gemini()
    # Make the patched backend call fail as soon as it is invoked.
    mock_generate_content.side_effect = Exception("Test exception")

    outcome = gemini.run(task="A cat")

    assert outcome is None
# Test Gemini process_img method
def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
    """Check ``process_img`` output when the file read is mocked.

    ``builtins.open`` is patched so that reading inside a ``with`` block
    yields fixed bytes. The test expects a one-element list holding a dict
    with ``mime_type`` "image/png" and those bytes under ``data``, and
    expects the file to have been opened as ``open(img, "rb")``.
    """
    gemini = Gemini(gemini_api_key="custom-api-key")
    image_path = "cat.png"
    payload = b"Mocked image data"

    with patch("builtins.open", create=True) as open_mock:
        # The object bound by ``with open(...) as f`` is what __enter__ returns.
        file_handle = open_mock.return_value.__enter__.return_value
        file_handle.read.return_value = payload
        result = gemini.process_img(image_path)

    expected = [{"mime_type": "image/png", "data": payload}]
    assert result == expected
    open_mock.assert_called_with(image_path, "rb")
# Test Gemini initialization with missing API key
def test_gemini_init_missing_api_key():
    """Expect a ``ValueError`` when ``Gemini`` is built with a ``None`` key."""
    expected_error = pytest.raises(
        ValueError, match="Please provide a Gemini API key"
    )
    with expected_error:
        Gemini(gemini_api_key=None)
# Test Gemini initialization with missing model name
def test_gemini_init_missing_model_name(mock_gemini_api_key):
    """Expect a ``ValueError`` when ``Gemini`` is built with ``model_name=None``.

    The ``mock_gemini_api_key`` fixture is requested so that an API key is
    present in the environment. Without it, the outcome of this test
    depends on whether ``GEMINI_API_KEY`` happens to be set on the machine
    running the suite, and a missing-key error could mask the model-name
    error being checked here.
    """
    with pytest.raises(ValueError, match="Please provide a model name"):
        Gemini(model_name=None)
# Test Gemini run method with empty task
def test_gemini_run_empty_task(mock_gemini_api_key, mock_genai_model):
    """Expect ``run`` to return ``None`` when the task is an empty string."""
    gemini = Gemini()
    outcome = gemini.run(task="")
    assert outcome is None
# Test Gemini run method with empty image
def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
    """Expect ``run`` to return ``None`` when ``img`` is an empty string."""
    gemini = Gemini()
    prompt, image_path = "A cat", ""
    outcome = gemini.run(task=prompt, img=image_path)
    assert outcome is None
# Test Gemini process_img method with missing image
def test_gemini_process_img_missing_image(
    mock_gemini_api_key, mock_genai_model
):
    """Expect ``process_img`` to raise ``ValueError`` when ``img`` is ``None``."""
    gemini = Gemini()
    expected_error = pytest.raises(
        ValueError, match="Please provide an image to process"
    )
    with expected_error:
        gemini.process_img(img=None)
# Test Gemini process_img method with missing image type
def test_gemini_process_img_missing_image_type(
    mock_gemini_api_key, mock_genai_model
):
    """Expect ``process_img`` to raise ``ValueError`` when ``type`` is ``None``."""
    gemini = Gemini()
    image_path = "cat.png"
    expected_error = pytest.raises(
        ValueError, match="Please provide the image type"
    )
    with expected_error:
        gemini.process_img(img=image_path, type=None)
# Test Gemini process_img method with missing Gemini API key
def test_gemini_process_img_missing_api_key(mock_genai_model):
    """Expect ``process_img`` to raise ``ValueError`` when the API key is unset.

    The instance is built with a valid key and the ``gemini_api_key``
    attribute is cleared afterwards. Building it with
    ``gemini_api_key=None`` directly would raise in the constructor (as
    ``test_gemini_init_missing_api_key`` asserts), outside the
    ``pytest.raises`` block, so ``process_img`` would never be exercised.
    """
    model = Gemini(gemini_api_key="custom-api-key")
    # Simulate the key going missing after a successful construction.
    model.gemini_api_key = None
    img = "cat.png"

    with pytest.raises(
        ValueError, match="Please provide a Gemini API key"
    ):
        model.process_img(img=img, type="image/png")
# Test Gemini run method with mocked image processing
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
@patch("swarms.models.gemini.Gemini.process_img")
def test_gemini_run_mock_img_processing(
    mock_process_img,
    mock_generate_content,
    mock_gemini_api_key,
    mock_genai_model,
):
    """Run ``Gemini.run`` with a task and image while both backends are mocked.

    Expects the ``text`` of the mocked response to be returned, the output
    of ``process_img`` to be passed to ``generate_content`` as
    ``content=[task, processed image]``, and ``process_img`` to be called
    with ``img=<path>``.
    """
    gemini = Gemini()
    prompt, image_path = "A cat", "cat.png"
    processed, expected_text = "Processed image", "Generated response"

    # Patch decorators are applied bottom-up, so here the first mock
    # argument belongs to process_img and the second to generate_content.
    mock_generate_content.return_value = Mock(text=expected_text)
    mock_process_img.return_value = processed

    result = gemini.run(task=prompt, img=image_path)

    assert result == expected_text
    mock_generate_content.assert_called_with(content=[prompt, processed])
    mock_process_img.assert_called_with(img=image_path)
# Test Gemini run method with mocked image processing and exception
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_mock_img_processing_exception(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    """Expect ``run`` to return ``None`` when ``process_img`` raises.

    Also expects ``generate_content`` never to be called and
    ``process_img`` to have been called with ``img=<path>``.
    """
    gemini = Gemini()
    prompt, image_path = "A cat", "cat.png"
    mock_process_img.side_effect = Exception("Test exception")

    outcome = gemini.run(task=prompt, img=image_path)

    assert outcome is None
    mock_generate_content.assert_not_called()
    mock_process_img.assert_called_with(img=image_path)
# Test Gemini run method with mocked image processing and different exception
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_mock_img_processing_different_exception(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    """Expect a ``ValueError`` from ``process_img`` to propagate out of ``run``.

    Also expects ``generate_content`` never to be called and
    ``process_img`` to have been called with ``img=<path>``.

    NOTE(review): test_gemini_run_mock_img_processing_exception expects a
    plain ``Exception`` from ``process_img`` to produce ``None`` instead of
    propagating -- confirm that ``run`` really treats the two differently.
    """
    gemini = Gemini()
    prompt, image_path = "A dog", "dog.png"
    mock_process_img.side_effect = ValueError("Test exception")

    with pytest.raises(ValueError):
        gemini.run(task=prompt, img=image_path)

    mock_generate_content.assert_not_called()
    mock_process_img.assert_called_with(img=image_path)
# Test Gemini run method with mocked image processing and no exception
@patch("swarms.models.gemini.Gemini.process_img")
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
def test_gemini_run_mock_img_processing_no_exception(
    mock_generate_content,
    mock_process_img,
    mock_gemini_api_key,
    mock_genai_model,
):
    """Run ``Gemini.run`` with an image when neither mocked backend raises.

    Expects ``run`` to return the ``text`` of the mocked response, to call
    ``generate_content`` exactly once, and to call ``process_img`` with
    ``img=<path>``.
    """
    model = Gemini()
    task = "A bird"
    img = "bird.png"
    # The other run tests in this module obtain the result from the
    # response's ``text`` attribute, so the mocked response must be an
    # object exposing ``text`` rather than a bare string.
    mock_generate_content.return_value = Mock(text="A bird is flying")

    response = model.run(task=task, img=img)

    assert response == "A bird is flying"
    mock_generate_content.assert_called_once()
    mock_process_img.assert_called_with(img=img)
# Test Gemini chat method
@patch("swarms.models.gemini.Gemini.chat")
def test_gemini_chat(mock_chat, mock_gemini_api_key):
    """Call ``Gemini.chat`` through a patched method and check the plumbing.

    ``chat`` itself is replaced by a mock, so this only verifies that the
    call reaches the patched attribute with the given message and that the
    mock's return value is handed back. The ``mock_gemini_api_key`` fixture
    is requested so that constructing ``Gemini()`` does not depend on the
    environment of the machine running the suite.
    """
    model = Gemini()
    message = "Hello, Gemini!"
    mock_chat.return_value = message

    response = model.chat(message)

    assert response == message
    # The class attribute is replaced by a plain mock, so no ``self`` is
    # recorded in the call arguments.
    mock_chat.assert_called_once_with(message)
# Test Gemini list_models method
@patch("swarms.models.gemini.Gemini.list_models")
def test_gemini_list_models(mock_list_models, mock_gemini_api_key):
    """Call ``Gemini.list_models`` through a patched method.

    ``list_models`` itself is replaced by a mock, so this only verifies
    that the call reaches the patched attribute without arguments and that
    the mock's return value is handed back. The ``mock_gemini_api_key``
    fixture is requested so that constructing ``Gemini()`` does not depend
    on the environment of the machine running the suite.
    """
    model = Gemini()
    expected = ["model1", "model2"]
    mock_list_models.return_value = expected

    response = model.list_models()

    assert response == ["model1", "model2"]
    mock_list_models.assert_called_once_with()
# Test Gemini stream_tokens method
@patch("swarms.models.gemini.Gemini.stream_tokens")
def test_gemini_stream_tokens(mock_stream_tokens, mock_gemini_api_key):
    """Call ``Gemini.stream_tokens`` through a patched method.

    ``stream_tokens`` itself is replaced by a mock, so this only verifies
    that the call reaches the patched attribute without arguments and that
    the mock's return value is handed back. The ``mock_gemini_api_key``
    fixture is requested so that constructing ``Gemini()`` does not depend
    on the environment of the machine running the suite.
    """
    model = Gemini()
    expected = ["token1", "token2"]
    mock_stream_tokens.return_value = expected

    response = model.stream_tokens()

    assert response == ["token1", "token2"]
    mock_stream_tokens.assert_called_once_with()
# Test Gemini process_img_pil method
@patch("swarms.models.gemini.Gemini.process_img_pil")
def test_gemini_process_img_pil(mock_process_img_pil, mock_gemini_api_key):
    """Call ``Gemini.process_img_pil`` through a patched method.

    ``process_img_pil`` itself is replaced by a mock, so this only verifies
    that the call reaches the patched attribute with the image path and
    that the mock's return value is handed back. The ``mock_gemini_api_key``
    fixture is requested so that constructing ``Gemini()`` does not depend
    on the environment of the machine running the suite.
    """
    model = Gemini()
    img = "bird.png"
    mock_process_img_pil.return_value = "processed image"

    response = model.process_img_pil(img)

    assert response == "processed image"
    # Also pins that the method was invoked exactly once.
    mock_process_img_pil.assert_called_once_with(img)
# Repeat the above tests for different scenarios or different methods in your Gemini class
# until you have 15 tests in total.