You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
308 lines
8.8 KiB
308 lines
8.8 KiB
6 months ago
|
from unittest.mock import Mock, patch
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from swarms.models.gemini import Gemini
|
||
|
|
||
|
|
||
|
# Define test fixtures
|
||
|
@pytest.fixture
|
||
|
def mock_gemini_api_key(monkeypatch):
|
||
|
monkeypatch.setenv("GEMINI_API_KEY", "mocked-api-key")
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def mock_genai_model():
|
||
|
return Mock()
|
||
|
|
||
|
|
||
|
# Test initialization of Gemini
|
||
|
def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
|
||
|
model = Gemini()
|
||
|
assert model.model_name == "gemini-pro"
|
||
|
assert model.gemini_api_key == "mocked-api-key"
|
||
|
assert model.model is mock_genai_model
|
||
|
|
||
|
|
||
|
def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model):
|
||
|
model = Gemini(
|
||
|
model_name="custom-model", gemini_api_key="custom-api-key"
|
||
|
)
|
||
|
assert model.model_name == "custom-model"
|
||
|
assert model.gemini_api_key == "custom-api-key"
|
||
|
assert model.model is mock_genai_model
|
||
|
|
||
|
|
||
|
# Test Gemini run method
|
||
|
@patch("swarms.models.gemini.Gemini.process_img")
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_with_img(
|
||
|
mock_generate_content,
|
||
|
mock_process_img,
|
||
|
mock_gemini_api_key,
|
||
|
mock_genai_model,
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
img = "cat.png"
|
||
|
response_mock = Mock(text="Generated response")
|
||
|
mock_generate_content.return_value = response_mock
|
||
|
mock_process_img.return_value = "Processed image"
|
||
|
|
||
|
response = model.run(task=task, img=img)
|
||
|
|
||
|
assert response == "Generated response"
|
||
|
mock_generate_content.assert_called_with(
|
||
|
content=[task, "Processed image"]
|
||
|
)
|
||
|
mock_process_img.assert_called_with(img=img)
|
||
|
|
||
|
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_without_img(
|
||
|
mock_generate_content, mock_gemini_api_key, mock_genai_model
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
response_mock = Mock(text="Generated response")
|
||
|
mock_generate_content.return_value = response_mock
|
||
|
|
||
|
response = model.run(task=task)
|
||
|
|
||
|
assert response == "Generated response"
|
||
|
mock_generate_content.assert_called_with(task=task)
|
||
|
|
||
|
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_exception(
|
||
|
mock_generate_content, mock_gemini_api_key, mock_genai_model
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
mock_generate_content.side_effect = Exception("Test exception")
|
||
|
|
||
|
response = model.run(task=task)
|
||
|
|
||
|
assert response is None
|
||
|
|
||
|
|
||
|
# Test Gemini process_img method
|
||
|
def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
|
||
|
model = Gemini(gemini_api_key="custom-api-key")
|
||
|
img = "cat.png"
|
||
|
img_data = b"Mocked image data"
|
||
|
|
||
|
with patch("builtins.open", create=True) as open_mock:
|
||
|
open_mock.return_value.__enter__.return_value.read.return_value = (
|
||
|
img_data
|
||
|
)
|
||
|
|
||
|
processed_img = model.process_img(img)
|
||
|
|
||
|
assert processed_img == [{"mime_type": "image/png", "data": img_data}]
|
||
|
open_mock.assert_called_with(img, "rb")
|
||
|
|
||
|
|
||
|
# Test Gemini initialization with missing API key
|
||
|
def test_gemini_init_missing_api_key():
|
||
|
with pytest.raises(
|
||
|
ValueError, match="Please provide a Gemini API key"
|
||
|
):
|
||
|
Gemini(gemini_api_key=None)
|
||
|
|
||
|
|
||
|
# Test Gemini initialization with missing model name
|
||
|
def test_gemini_init_missing_model_name():
|
||
|
with pytest.raises(ValueError, match="Please provide a model name"):
|
||
|
Gemini(model_name=None)
|
||
|
|
||
|
|
||
|
# Test Gemini run method with empty task
|
||
|
def test_gemini_run_empty_task(mock_gemini_api_key, mock_genai_model):
|
||
|
model = Gemini()
|
||
|
task = ""
|
||
|
response = model.run(task=task)
|
||
|
assert response is None
|
||
|
|
||
|
|
||
|
# Test Gemini run method with empty image
|
||
|
def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
img = ""
|
||
|
response = model.run(task=task, img=img)
|
||
|
assert response is None
|
||
|
|
||
|
|
||
|
# Test Gemini process_img method with missing image
|
||
|
def test_gemini_process_img_missing_image(
|
||
|
mock_gemini_api_key, mock_genai_model
|
||
|
):
|
||
|
model = Gemini()
|
||
|
img = None
|
||
|
with pytest.raises(
|
||
|
ValueError, match="Please provide an image to process"
|
||
|
):
|
||
|
model.process_img(img=img)
|
||
|
|
||
|
|
||
|
# Test Gemini process_img method with missing image type
|
||
|
def test_gemini_process_img_missing_image_type(
|
||
|
mock_gemini_api_key, mock_genai_model
|
||
|
):
|
||
|
model = Gemini()
|
||
|
img = "cat.png"
|
||
|
with pytest.raises(ValueError, match="Please provide the image type"):
|
||
|
model.process_img(img=img, type=None)
|
||
|
|
||
|
|
||
|
# Test Gemini process_img method with missing Gemini API key
|
||
|
def test_gemini_process_img_missing_api_key(mock_genai_model):
|
||
|
model = Gemini(gemini_api_key=None)
|
||
|
img = "cat.png"
|
||
|
with pytest.raises(
|
||
|
ValueError, match="Please provide a Gemini API key"
|
||
|
):
|
||
|
model.process_img(img=img, type="image/png")
|
||
|
|
||
|
|
||
|
# Test Gemini run method with mocked image processing
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
@patch("swarms.models.gemini.Gemini.process_img")
|
||
|
def test_gemini_run_mock_img_processing(
|
||
|
mock_process_img,
|
||
|
mock_generate_content,
|
||
|
mock_gemini_api_key,
|
||
|
mock_genai_model,
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
img = "cat.png"
|
||
|
response_mock = Mock(text="Generated response")
|
||
|
mock_generate_content.return_value = response_mock
|
||
|
mock_process_img.return_value = "Processed image"
|
||
|
|
||
|
response = model.run(task=task, img=img)
|
||
|
|
||
|
assert response == "Generated response"
|
||
|
mock_generate_content.assert_called_with(
|
||
|
content=[task, "Processed image"]
|
||
|
)
|
||
|
mock_process_img.assert_called_with(img=img)
|
||
|
|
||
|
|
||
|
# Test Gemini run method with mocked image processing and exception
|
||
|
@patch("swarms.models.gemini.Gemini.process_img")
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_mock_img_processing_exception(
|
||
|
mock_generate_content,
|
||
|
mock_process_img,
|
||
|
mock_gemini_api_key,
|
||
|
mock_genai_model,
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A cat"
|
||
|
img = "cat.png"
|
||
|
mock_process_img.side_effect = Exception("Test exception")
|
||
|
|
||
|
response = model.run(task=task, img=img)
|
||
|
|
||
|
assert response is None
|
||
|
mock_generate_content.assert_not_called()
|
||
|
mock_process_img.assert_called_with(img=img)
|
||
|
|
||
|
|
||
|
# Test Gemini run method with mocked image processing and different exception
|
||
|
@patch("swarms.models.gemini.Gemini.process_img")
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_mock_img_processing_different_exception(
|
||
|
mock_generate_content,
|
||
|
mock_process_img,
|
||
|
mock_gemini_api_key,
|
||
|
mock_genai_model,
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A dog"
|
||
|
img = "dog.png"
|
||
|
mock_process_img.side_effect = ValueError("Test exception")
|
||
|
|
||
|
with pytest.raises(ValueError):
|
||
|
model.run(task=task, img=img)
|
||
|
|
||
|
mock_generate_content.assert_not_called()
|
||
|
mock_process_img.assert_called_with(img=img)
|
||
|
|
||
|
|
||
|
# Test Gemini run method with mocked image processing and no exception
|
||
|
@patch("swarms.models.gemini.Gemini.process_img")
|
||
|
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||
|
def test_gemini_run_mock_img_processing_no_exception(
|
||
|
mock_generate_content,
|
||
|
mock_process_img,
|
||
|
mock_gemini_api_key,
|
||
|
mock_genai_model,
|
||
|
):
|
||
|
model = Gemini()
|
||
|
task = "A bird"
|
||
|
img = "bird.png"
|
||
|
mock_generate_content.return_value = "A bird is flying"
|
||
|
|
||
|
response = model.run(task=task, img=img)
|
||
|
|
||
|
assert response == "A bird is flying"
|
||
|
mock_generate_content.assert_called_once()
|
||
|
mock_process_img.assert_called_with(img=img)
|
||
|
|
||
|
|
||
|
# Test Gemini chat method
|
||
|
@patch("swarms.models.gemini.Gemini.chat")
|
||
|
def test_gemini_chat(mock_chat):
|
||
|
model = Gemini()
|
||
|
mock_chat.return_value = "Hello, Gemini!"
|
||
|
|
||
|
response = model.chat("Hello, Gemini!")
|
||
|
|
||
|
assert response == "Hello, Gemini!"
|
||
|
mock_chat.assert_called_once()
|
||
|
|
||
|
|
||
|
# Test Gemini list_models method
|
||
|
@patch("swarms.models.gemini.Gemini.list_models")
|
||
|
def test_gemini_list_models(mock_list_models):
|
||
|
model = Gemini()
|
||
|
mock_list_models.return_value = ["model1", "model2"]
|
||
|
|
||
|
response = model.list_models()
|
||
|
|
||
|
assert response == ["model1", "model2"]
|
||
|
mock_list_models.assert_called_once()
|
||
|
|
||
|
|
||
|
# Test Gemini stream_tokens method
|
||
|
@patch("swarms.models.gemini.Gemini.stream_tokens")
|
||
|
def test_gemini_stream_tokens(mock_stream_tokens):
|
||
|
model = Gemini()
|
||
|
mock_stream_tokens.return_value = ["token1", "token2"]
|
||
|
|
||
|
response = model.stream_tokens()
|
||
|
|
||
|
assert response == ["token1", "token2"]
|
||
|
mock_stream_tokens.assert_called_once()
|
||
|
|
||
|
|
||
|
# Test Gemini process_img_pil method
|
||
|
@patch("swarms.models.gemini.Gemini.process_img_pil")
|
||
|
def test_gemini_process_img_pil(mock_process_img_pil):
|
||
|
model = Gemini()
|
||
|
img = "bird.png"
|
||
|
mock_process_img_pil.return_value = "processed image"
|
||
|
|
||
|
response = model.process_img_pil(img)
|
||
|
|
||
|
assert response == "processed image"
|
||
|
mock_process_img_pil.assert_called_with(img)
|
||
|
|
||
|
|
||
|
# Repeat the above tests for different scenarios or different methods in your Gemini class
|
||
|
# until you have 15 tests in total.
|