parent
0d1fbb3e46
commit
0e49d09744
@ -1,56 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
from langchain import HuggingFaceHub
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
|
|
||||||
from swarms.models.llm import LLM
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLM(unittest.TestCase):
|
|
||||||
@patch.object(HuggingFaceHub, "__init__", return_value=None)
|
|
||||||
@patch.object(ChatOpenAI, "__init__", return_value=None)
|
|
||||||
def setUp(self, mock_hf_init, mock_openai_init):
|
|
||||||
self.llm_openai = LLM(openai_api_key="mock_openai_key")
|
|
||||||
self.llm_hf = LLM(
|
|
||||||
hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token"
|
|
||||||
)
|
|
||||||
self.prompt = "Who won the FIFA World Cup in 1998?"
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
self.assertEqual(
|
|
||||||
self.llm_openai.openai_api_key, "mock_openai_key"
|
|
||||||
)
|
|
||||||
self.assertEqual(self.llm_hf.hf_repo_id, "mock_repo_id")
|
|
||||||
self.assertEqual(self.llm_hf.hf_api_token, "mock_hf_token")
|
|
||||||
|
|
||||||
@patch.object(HuggingFaceHub, "run", return_value="France")
|
|
||||||
@patch.object(ChatOpenAI, "run", return_value="France")
|
|
||||||
def test_run(self, mock_hf_run, mock_openai_run):
|
|
||||||
result_openai = self.llm_openai.run(self.prompt)
|
|
||||||
mock_openai_run.assert_called_once()
|
|
||||||
self.assertEqual(result_openai, "France")
|
|
||||||
|
|
||||||
result_hf = self.llm_hf.run(self.prompt)
|
|
||||||
mock_hf_run.assert_called_once()
|
|
||||||
self.assertEqual(result_hf, "France")
|
|
||||||
|
|
||||||
def test_error_on_no_keys(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
LLM()
|
|
||||||
|
|
||||||
@patch.object(os, "environ", {})
|
|
||||||
def test_error_on_missing_hf_token(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
LLM(hf_repo_id="mock_repo_id")
|
|
||||||
|
|
||||||
@patch.dict(
|
|
||||||
os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"}
|
|
||||||
)
|
|
||||||
def test_hf_token_from_env(self):
|
|
||||||
llm = LLM(hf_repo_id="mock_repo_id")
|
|
||||||
self.assertEqual(llm.hf_api_token, "mock_hf_token")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
@ -1,454 +0,0 @@
|
|||||||
import os
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from openai import OpenAIError
|
|
||||||
from PIL import Image
|
|
||||||
from termcolor import colored
|
|
||||||
|
|
||||||
from swarms.models.dalle3 import Dalle3
|
|
||||||
|
|
||||||
|
|
||||||
# Mocking the OpenAI client to avoid making actual API calls during testing
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_openai_client():
|
|
||||||
return Mock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dalle3(mock_openai_client):
|
|
||||||
return Dalle3(client=mock_openai_client)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_success(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
mock_openai_client.images.generate.return_value = Mock(
|
|
||||||
data=[Mock(url=expected_img_url)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
img_url = dalle3(task)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert img_url == expected_img_url
|
|
||||||
mock_openai_client.images.generate.assert_called_once_with(
|
|
||||||
prompt=task, n=4
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
|
|
||||||
# Arrange
|
|
||||||
task = "Invalid task"
|
|
||||||
expected_error_message = "Error running Dalle3: API Error"
|
|
||||||
|
|
||||||
# Mocking OpenAIError
|
|
||||||
mock_openai_client.images.generate.side_effect = OpenAIError(
|
|
||||||
expected_error_message,
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(OpenAIError) as excinfo:
|
|
||||||
dalle3(task)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
mock_openai_client.images.generate.assert_called_once_with(
|
|
||||||
prompt=task, n=4
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the error message is printed in red
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert colored(expected_error_message, "red") in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_success(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
|
|
||||||
mock_openai_client.images.create_variation.return_value = Mock(
|
|
||||||
data=[Mock(url=expected_variation_url)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
variation_img_url = dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert variation_img_url == expected_variation_url
|
|
||||||
mock_openai_client.images.create_variation.assert_called_once()
|
|
||||||
_, kwargs = mock_openai_client.images.create_variation.call_args
|
|
||||||
assert kwargs["img"] is not None
|
|
||||||
assert kwargs["n"] == 4
|
|
||||||
assert kwargs["size"] == "1024x1024"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_failure(
|
|
||||||
dalle3, mock_openai_client, capsys
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
expected_error_message = "Error running Dalle3: API Error"
|
|
||||||
|
|
||||||
# Mocking OpenAIError
|
|
||||||
mock_openai_client.images.create_variation.side_effect = (
|
|
||||||
OpenAIError(
|
|
||||||
expected_error_message,
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(OpenAIError) as excinfo:
|
|
||||||
dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
mock_openai_client.images.create_variation.assert_called_once()
|
|
||||||
|
|
||||||
# Ensure the error message is printed in red
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert colored(expected_error_message, "red") in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_read_img():
|
|
||||||
# Arrange
|
|
||||||
img_path = "test_image.png"
|
|
||||||
img = Image.new("RGB", (512, 512))
|
|
||||||
|
|
||||||
# Save the image temporarily
|
|
||||||
img.save(img_path)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
dalle3 = Dalle3()
|
|
||||||
img_loaded = dalle3.read_img(img_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert isinstance(img_loaded, Image.Image)
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
os.remove(img_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_set_width_height():
|
|
||||||
# Arrange
|
|
||||||
img = Image.new("RGB", (512, 512))
|
|
||||||
width = 256
|
|
||||||
height = 256
|
|
||||||
|
|
||||||
# Act
|
|
||||||
dalle3 = Dalle3()
|
|
||||||
img_resized = dalle3.set_width_height(img, width, height)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert img_resized.size == (width, height)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_convert_to_bytesio():
|
|
||||||
# Arrange
|
|
||||||
img = Image.new("RGB", (512, 512))
|
|
||||||
expected_format = "PNG"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
dalle3 = Dalle3()
|
|
||||||
img_bytes = dalle3.convert_to_bytesio(img, format=expected_format)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert isinstance(img_bytes, bytes)
|
|
||||||
assert img_bytes.startswith(b"\x89PNG")
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
mock_openai_client.images.generate.return_value = Mock(
|
|
||||||
data=[Mock(url=expected_img_url)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
img_url1 = dalle3(task)
|
|
||||||
img_url2 = dalle3(task)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert img_url1 == expected_img_url
|
|
||||||
assert img_url2 == expected_img_url
|
|
||||||
assert mock_openai_client.images.generate.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
task = "A" * 2048 # Input longer than API's limit
|
|
||||||
expected_error_message = "Error running Dalle3: API Error"
|
|
||||||
mock_openai_client.images.generate.side_effect = OpenAIError(
|
|
||||||
expected_error_message,
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(OpenAIError) as excinfo:
|
|
||||||
dalle3(task)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_with_invalid_image_url(
|
|
||||||
dalle3, mock_openai_client
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://invalid-image-url.com"
|
|
||||||
expected_error_message = "Error running Dalle3: Invalid image URL"
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_set_width_height_invalid_dimensions(dalle3):
|
|
||||||
# Arrange
|
|
||||||
img = dalle3.read_img("test_image.png")
|
|
||||||
width = 0
|
|
||||||
height = -1
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
dalle3.set_width_height(img, width, height)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
|
|
||||||
# Arrange
|
|
||||||
img = dalle3.read_img("test_image.png")
|
|
||||||
invalid_format = "invalid_format"
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
dalle3.convert_to_bytesio(img, format=invalid_format)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_with_retry(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
|
|
||||||
# Simulate a retry scenario
|
|
||||||
mock_openai_client.images.generate.side_effect = [
|
|
||||||
OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
),
|
|
||||||
Mock(data=[Mock(url=expected_img_url)]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
img_url = dalle3(task)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert img_url == expected_img_url
|
|
||||||
assert mock_openai_client.images.generate.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_with_retry(
|
|
||||||
dalle3, mock_openai_client
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
|
|
||||||
|
|
||||||
# Simulate a retry scenario
|
|
||||||
mock_openai_client.images.create_variation.side_effect = [
|
|
||||||
OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
),
|
|
||||||
Mock(data=[Mock(url=expected_variation_url)]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
variation_img_url = dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert variation_img_url == expected_variation_url
|
|
||||||
assert mock_openai_client.images.create_variation.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_exception_logging(
|
|
||||||
dalle3, mock_openai_client, capsys
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
expected_error_message = "Error running Dalle3: API Error"
|
|
||||||
|
|
||||||
# Mocking OpenAIError
|
|
||||||
mock_openai_client.images.generate.side_effect = OpenAIError(
|
|
||||||
expected_error_message,
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
with pytest.raises(OpenAIError):
|
|
||||||
dalle3(task)
|
|
||||||
|
|
||||||
# Assert that the error message is logged
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert expected_error_message in captured.err
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_exception_logging(
|
|
||||||
dalle3, mock_openai_client, capsys
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
expected_error_message = "Error running Dalle3: API Error"
|
|
||||||
|
|
||||||
# Mocking OpenAIError
|
|
||||||
mock_openai_client.images.create_variation.side_effect = (
|
|
||||||
OpenAIError(
|
|
||||||
expected_error_message,
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
with pytest.raises(OpenAIError):
|
|
||||||
dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
# Assert that the error message is logged
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert expected_error_message in captured.err
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_read_img_invalid_path(dalle3):
|
|
||||||
# Arrange
|
|
||||||
invalid_img_path = "invalid_image_path.png"
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
dalle3.read_img(invalid_img_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_no_api_key():
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
dalle3 = Dalle3(api_key=None)
|
|
||||||
expected_error_message = (
|
|
||||||
"Error running Dalle3: API Key is missing"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
dalle3(task)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_no_api_key():
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
dalle3 = Dalle3(api_key=None)
|
|
||||||
expected_error_message = (
|
|
||||||
"Error running Dalle3: API Key is missing"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
assert str(excinfo.value) == expected_error_message
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_with_retry_max_retries_exceeded(
|
|
||||||
dalle3, mock_openai_client
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
|
|
||||||
# Simulate max retries exceeded
|
|
||||||
mock_openai_client.images.generate.side_effect = OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(OpenAIError) as excinfo:
|
|
||||||
dalle3(task)
|
|
||||||
|
|
||||||
assert "Retry limit exceeded" in str(excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_with_retry_max_retries_exceeded(
|
|
||||||
dalle3, mock_openai_client
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
|
|
||||||
# Simulate max retries exceeded
|
|
||||||
mock_openai_client.images.create_variation.side_effect = (
|
|
||||||
OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act and assert
|
|
||||||
with pytest.raises(OpenAIError) as excinfo:
|
|
||||||
dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
assert "Retry limit exceeded" in str(excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
|
|
||||||
# Arrange
|
|
||||||
task = "A painting of a dog"
|
|
||||||
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
|
|
||||||
# Simulate success after a retry
|
|
||||||
mock_openai_client.images.generate.side_effect = [
|
|
||||||
OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
),
|
|
||||||
Mock(data=[Mock(url=expected_img_url)]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
img_url = dalle3(task)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert img_url == expected_img_url
|
|
||||||
assert mock_openai_client.images.generate.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_dalle3_create_variations_retry_with_success(
|
|
||||||
dalle3, mock_openai_client
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
|
||||||
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
|
|
||||||
|
|
||||||
# Simulate success after a retry
|
|
||||||
mock_openai_client.images.create_variation.side_effect = [
|
|
||||||
OpenAIError(
|
|
||||||
"Temporary error",
|
|
||||||
http_status=500,
|
|
||||||
error="Internal Server Error",
|
|
||||||
),
|
|
||||||
Mock(data=[Mock(url=expected_variation_url)]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
variation_img_url = dalle3.create_variations(img_url)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert variation_img_url == expected_variation_url
|
|
||||||
assert mock_openai_client.images.create_variation.call_count == 2
|
|
Loading…
Reference in new issue