import os
from unittest.mock import Mock

import pytest
from openai import OpenAIError
from PIL import Image
from termcolor import colored

from playground.models.dalle3 import Dalle3


# Mocking the OpenAI client to avoid making actual API calls during testing
@pytest.fixture
def mock_openai_client():
    """Provide a stand-in OpenAI client so that no test performs a real API call."""
    fake_client = Mock()
    return fake_client


@pytest.fixture
def dalle3(mock_openai_client):
    """Build a ``Dalle3`` instance that talks to the mocked OpenAI client."""
    model = Dalle3(client=mock_openai_client)
    return model


def test_dalle3_call_success(dalle3, mock_openai_client):
    """Calling the model returns the URL carried by the first image in the response."""
    prompt = "A painting of a dog"
    url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"

    # Stub the image-generation endpoint with a response holding a single image.
    generate = mock_openai_client.images.generate
    generate.return_value = Mock(data=[Mock(url=url)])

    result = dalle3(prompt)

    assert result == url
    generate.assert_called_once_with(prompt=prompt, n=4)


def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
    """An API error is re-raised unchanged and its message is printed in red on stdout.

    NOTE(review): this test expects exactly one API call and output on stdout,
    whereas ``test_dalle3_call_with_retry`` expects a second attempt and
    ``test_dalle3_call_exception_logging`` expects the message on stderr --
    confirm which behaviour ``Dalle3`` actually implements.
    """
    # Arrange
    task = "Invalid task"
    expected_error_message = "Error running Dalle3: API Error"

    # Build the error from the message alone: OpenAIError has no ``error``
    # keyword, so passing one raises TypeError before the test even runs.
    mock_openai_client.images.generate.side_effect = OpenAIError(
        expected_error_message
    )

    # Act and assert
    with pytest.raises(OpenAIError) as excinfo:
        dalle3(task)

    assert str(excinfo.value) == expected_error_message
    mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4)

    # Ensure the error message is printed in red
    captured = capsys.readouterr()
    assert colored(expected_error_message, "red") in captured.out


def test_dalle3_create_variations_success(dalle3, mock_openai_client):
    """``create_variations`` returns the variation URL and forwards img/n/size keywords."""
    source_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"

    # Stub the variation endpoint with a response holding a single image.
    create_variation = mock_openai_client.images.create_variation
    create_variation.return_value = Mock(data=[Mock(url=variation_url)])

    result = dalle3.create_variations(source_url)

    assert result == variation_url
    create_variation.assert_called_once()

    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    call_kwargs = create_variation.call_args[1]
    assert call_kwargs["img"] is not None
    assert call_kwargs["n"] == 4
    assert call_kwargs["size"] == "1024x1024"


def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
    """A variation API error is re-raised unchanged and printed in red on stdout.

    NOTE(review): this test expects exactly one API call, whereas
    ``test_dalle3_create_variations_with_retry`` expects a second attempt --
    confirm which behaviour ``Dalle3`` actually implements.
    """
    # Arrange
    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    expected_error_message = "Error running Dalle3: API Error"

    # Build the error from the message alone: OpenAIError has no ``error``
    # keyword, so passing one raises TypeError before the test even runs.
    mock_openai_client.images.create_variation.side_effect = OpenAIError(
        expected_error_message
    )

    # Act and assert
    with pytest.raises(OpenAIError) as excinfo:
        dalle3.create_variations(img_url)

    assert str(excinfo.value) == expected_error_message
    mock_openai_client.images.create_variation.assert_called_once()

    # Ensure the error message is printed in red
    captured = capsys.readouterr()
    assert colored(expected_error_message, "red") in captured.out


def test_dalle3_read_img():
    """``read_img`` loads an image file from disk and returns a PIL image.

    The fixture image lives in a temporary directory so the test never writes
    into the working directory and the file is removed even when the call or
    the assertion fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Arrange: write a blank image to a throwaway location
        img_path = os.path.join(tmp_dir, "test_image.png")
        Image.new("RGB", (512, 512)).save(img_path)

        # Act
        dalle3 = Dalle3()
        img_loaded = dalle3.read_img(img_path)

        # Assert
        assert isinstance(img_loaded, Image.Image)


def test_dalle3_set_width_height():
    """``set_width_height`` returns an image with the requested dimensions."""
    target_size = (256, 256)
    source = Image.new("RGB", (512, 512))

    model = Dalle3()
    resized = model.set_width_height(source, *target_size)

    assert resized.size == target_size


def test_dalle3_convert_to_bytesio():
    """``convert_to_bytesio`` serialises an image to raw PNG bytes."""
    picture = Image.new("RGB", (512, 512))

    model = Dalle3()
    payload = model.convert_to_bytesio(picture, format="PNG")

    assert isinstance(payload, bytes)
    # Every PNG stream opens with the signature bytes 0x89 'P' 'N' 'G'.
    assert payload.startswith(b"\x89PNG")


def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
    """Two consecutive calls each hit the API and each return the image URL."""
    prompt = "A painting of a dog"
    url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    generate = mock_openai_client.images.generate
    generate.return_value = Mock(data=[Mock(url=url)])

    # Invoke the model twice with the same prompt.
    results = [dalle3(prompt) for _ in range(2)]

    assert results == [url, url]
    assert generate.call_count == 2


def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
    """An API error raised for an oversized prompt propagates to the caller.

    The mocked client raises regardless of the prompt, so this only checks
    that the error surfaces; it does not exercise any real length limit.
    """
    # Arrange
    task = "A" * 2048  # Input longer than API's limit
    expected_error_message = "Error running Dalle3: API Error"
    # Build the error from the message alone: OpenAIError has no ``error``
    # keyword, so passing one raises TypeError before the test even runs.
    mock_openai_client.images.generate.side_effect = OpenAIError(
        expected_error_message
    )

    # Act and assert
    with pytest.raises(OpenAIError) as excinfo:
        dalle3(task)

    assert str(excinfo.value) == expected_error_message


def test_dalle3_create_variations_with_invalid_image_url(dalle3, mock_openai_client):
    """``create_variations`` rejects an unusable image URL with ``ValueError``.

    NOTE(review): the mocked client is left unconfigured here, so the
    ``ValueError`` presumably has to come from ``Dalle3``'s own validation --
    confirm against the implementation.
    """
    bad_url = "https://invalid-image-url.com"

    with pytest.raises(ValueError) as raised:
        dalle3.create_variations(bad_url)

    assert str(raised.value) == "Error running Dalle3: Invalid image URL"


def test_dalle3_set_width_height_invalid_dimensions(dalle3):
    """``set_width_height`` raises ``ValueError`` for zero or negative dimensions."""
    # Arrange: build the image in memory rather than reading "test_image.png",
    # which no test leaves on disk, so the test does not depend on a stray file.
    img = Image.new("RGB", (512, 512))
    width = 0
    height = -1

    # Act and assert
    with pytest.raises(ValueError):
        dalle3.set_width_height(img, width, height)


def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
    """``convert_to_bytesio`` raises ``ValueError`` for an unknown image format."""
    # Arrange: build the image in memory rather than reading "test_image.png",
    # which no test leaves on disk, so the test does not depend on a stray file.
    img = Image.new("RGB", (512, 512))
    invalid_format = "invalid_format"

    # Act and assert
    with pytest.raises(ValueError):
        dalle3.convert_to_bytesio(img, format=invalid_format)


def test_dalle3_call_with_retry(dalle3, mock_openai_client):
    """A transient API error on the first attempt is retried and the second succeeds.

    NOTE(review): ``test_dalle3_call_failure`` expects the error to be raised
    after a single call; confirm that ``Dalle3`` really retries.
    """
    # Arrange
    task = "A painting of a dog"
    expected_img_url = (
        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    )

    # Simulate a retry scenario: first call raises, second call returns a
    # response. OpenAIError is built from the message alone because it has no
    # ``error`` keyword (passing one raises TypeError).
    mock_openai_client.images.generate.side_effect = [
        OpenAIError("Temporary error"),
        Mock(data=[Mock(url=expected_img_url)]),
    ]

    # Act
    img_url = dalle3(task)

    # Assert
    assert img_url == expected_img_url
    assert mock_openai_client.images.generate.call_count == 2


def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
    """A transient variation error on the first attempt is retried and then succeeds.

    NOTE(review): ``test_dalle3_create_variations_failure`` expects the error
    to be raised after a single call; confirm that ``Dalle3`` really retries.
    """
    # Arrange
    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    expected_variation_url = (
        "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
    )

    # Simulate a retry scenario: first call raises, second call returns a
    # response. OpenAIError is built from the message alone because it has no
    # ``error`` keyword (passing one raises TypeError).
    mock_openai_client.images.create_variation.side_effect = [
        OpenAIError("Temporary error"),
        Mock(data=[Mock(url=expected_variation_url)]),
    ]

    # Act
    variation_img_url = dalle3.create_variations(img_url)

    # Assert
    assert variation_img_url == expected_variation_url
    assert mock_openai_client.images.create_variation.call_count == 2


def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
    """An API error raised during generation is reported on stderr.

    NOTE(review): ``test_dalle3_call_failure`` expects the same message on
    stdout; confirm which stream ``Dalle3`` writes to.
    """
    # Arrange
    task = "A painting of a dog"
    expected_error_message = "Error running Dalle3: API Error"

    # Build the error from the message alone: OpenAIError has no ``error``
    # keyword, so passing one raises TypeError before the test even runs.
    mock_openai_client.images.generate.side_effect = OpenAIError(
        expected_error_message
    )

    # Act
    with pytest.raises(OpenAIError):
        dalle3(task)

    # Assert that the error message is logged
    captured = capsys.readouterr()
    assert expected_error_message in captured.err


def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys):
    """An API error raised while creating variations is reported on stderr.

    NOTE(review): ``test_dalle3_create_variations_failure`` expects the same
    message on stdout; confirm which stream ``Dalle3`` writes to.
    """
    # Arrange
    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    expected_error_message = "Error running Dalle3: API Error"

    # Build the error from the message alone: OpenAIError has no ``error``
    # keyword, so passing one raises TypeError before the test even runs.
    mock_openai_client.images.create_variation.side_effect = OpenAIError(
        expected_error_message
    )

    # Act
    with pytest.raises(OpenAIError):
        dalle3.create_variations(img_url)

    # Assert that the error message is logged
    captured = capsys.readouterr()
    assert expected_error_message in captured.err


def test_dalle3_read_img_invalid_path(dalle3):
    """``read_img`` raises ``FileNotFoundError`` when the path does not exist."""
    missing_path = "invalid_image_path.png"

    with pytest.raises(FileNotFoundError):
        dalle3.read_img(missing_path)


def test_dalle3_call_no_api_key():
    """Generating an image without an API key raises a descriptive ``ValueError``."""
    prompt = "A painting of a dog"
    keyless_model = Dalle3(api_key=None)

    with pytest.raises(ValueError) as raised:
        keyless_model(prompt)

    assert str(raised.value) == "Error running Dalle3: API Key is missing"


def test_dalle3_create_variations_no_api_key():
    """Creating variations without an API key raises a descriptive ``ValueError``."""
    source_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    keyless_model = Dalle3(api_key=None)

    with pytest.raises(ValueError) as raised:
        keyless_model.create_variations(source_url)

    assert str(raised.value) == "Error running Dalle3: API Key is missing"


def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
    """When every attempt fails, the raised error reports an exhausted retry limit.

    NOTE(review): ``test_dalle3_call_failure`` expects the original error
    message to be raised unchanged; confirm that ``Dalle3`` wraps it with
    "Retry limit exceeded".
    """
    # Arrange
    task = "A painting of a dog"

    # Simulate max retries exceeded: every call raises. OpenAIError is built
    # from the message alone because it has no ``error`` keyword (passing one
    # raises TypeError).
    mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error")

    # Act and assert
    with pytest.raises(OpenAIError) as excinfo:
        dalle3(task)

    assert "Retry limit exceeded" in str(excinfo.value)


def test_dalle3_create_variations_with_retry_max_retries_exceeded(
    dalle3, mock_openai_client
):
    """When every variation attempt fails, the error reports an exhausted retry limit.

    NOTE(review): ``test_dalle3_create_variations_failure`` expects the
    original error message to be raised unchanged; confirm that ``Dalle3``
    wraps it with "Retry limit exceeded".
    """
    # Arrange
    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"

    # Simulate max retries exceeded: every call raises. OpenAIError is built
    # from the message alone because it has no ``error`` keyword (passing one
    # raises TypeError).
    mock_openai_client.images.create_variation.side_effect = OpenAIError(
        "Temporary error"
    )

    # Act and assert
    with pytest.raises(OpenAIError) as excinfo:
        dalle3.create_variations(img_url)

    assert "Retry limit exceeded" in str(excinfo.value)


def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
    """Generation succeeds on the second attempt after one transient API error.

    NOTE(review): this covers the same scenario as
    ``test_dalle3_call_with_retry``; consider keeping only one of the two.
    """
    # Arrange
    task = "A painting of a dog"
    expected_img_url = (
        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    )

    # Simulate success after a retry: first call raises, second returns a
    # response. OpenAIError is built from the message alone because it has no
    # ``error`` keyword (passing one raises TypeError).
    mock_openai_client.images.generate.side_effect = [
        OpenAIError("Temporary error"),
        Mock(data=[Mock(url=expected_img_url)]),
    ]

    # Act
    img_url = dalle3(task)

    # Assert
    assert img_url == expected_img_url
    assert mock_openai_client.images.generate.call_count == 2


def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
    """Variation creation succeeds on the second attempt after one transient error.

    NOTE(review): this covers the same scenario as
    ``test_dalle3_create_variations_with_retry``; consider keeping only one.
    """
    # Arrange
    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    expected_variation_url = (
        "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
    )

    # Simulate success after a retry: first call raises, second returns a
    # response. OpenAIError is built from the message alone because it has no
    # ``error`` keyword (passing one raises TypeError).
    mock_openai_client.images.create_variation.side_effect = [
        OpenAIError("Temporary error"),
        Mock(data=[Mock(url=expected_variation_url)]),
    ]

    # Act
    variation_img_url = dalle3.create_variations(img_url)

    # Assert
    assert variation_img_url == expected_variation_url
    assert mock_openai_client.images.create_variation.call_count == 2