diff --git a/swarms/utils/litellm_wrapper.py b/swarms/utils/litellm_wrapper.py index 063e6ce3..52550800 100644 --- a/swarms/utils/litellm_wrapper.py +++ b/swarms/utils/litellm_wrapper.py @@ -212,44 +212,62 @@ class LiteLLM: Process vision input specifically for Anthropic models. Handles Anthropic's specific image format requirements. """ - # Get base64 encoded image - image_url = get_image_base64(image) - - # Extract mime type from the data URI or use default - mime_type = "image/jpeg" # default - if "data:" in image_url and ";base64," in image_url: - mime_type = image_url.split(";base64,")[0].split("data:")[ - 1 - ] - - # Ensure mime type is one of the supported formats - supported_formats = [ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", - ] - if mime_type not in supported_formats: - mime_type = ( - "image/jpeg" # fallback to jpeg if unsupported + # Check if we can use direct URL + if self._should_use_direct_url(image): + # Use direct URL without base64 conversion + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": task}, + { + "type": "image_url", + "image_url": { + "url": image, + }, + }, + ], + } ) + else: + # Fall back to base64 conversion for local files + image_url = get_image_base64(image) + + # Extract mime type from the data URI or use default + mime_type = "image/jpeg" # default + if "data:" in image_url and ";base64," in image_url: + mime_type = image_url.split(";base64,")[0].split("data:")[ + 1 + ] + + # Ensure mime type is one of the supported formats + supported_formats = [ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + ] + if mime_type not in supported_formats: + mime_type = ( + "image/jpeg" # fallback to jpeg if unsupported + ) - # Construct Anthropic vision message - messages.append( - { - "role": "user", - "content": [ - {"type": "text", "text": task}, - { - "type": "image_url", - "image_url": { - "url": image_url, - "format": mime_type, + # Construct Anthropic vision message with base64 + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": task}, + { + "type": "image_url", + "image_url": { + "url": image_url, + "format": mime_type, + }, }, - }, - ], - } - ) + ], + } + ) return messages @@ -260,21 +278,29 @@ class LiteLLM: Process vision input specifically for OpenAI models. Handles OpenAI's specific image format requirements. """ - # Get base64 encoded image with proper format - image_url = get_image_base64(image) - - # Prepare vision message - vision_message = { - "type": "image_url", - "image_url": {"url": image_url}, - } - - # Add format for specific models - extension = Path(image).suffix.lower() - mime_type = ( - f"image/{extension[1:]}" if extension else "image/jpeg" - ) - vision_message["image_url"]["format"] = mime_type + # Check if we can use direct URL + if self._should_use_direct_url(image): + # Use direct URL without base64 conversion + vision_message = { + "type": "image_url", + "image_url": {"url": image}, + } + else: + # Fall back to base64 conversion for local files + image_url = get_image_base64(image) + + # Prepare vision message with base64 + vision_message = { + "type": "image_url", + "image_url": {"url": image_url}, + } + + # Add format for specific models + extension = Path(image).suffix.lower() + mime_type = ( + f"image/{extension[1:]}" if extension else "image/jpeg" + ) + vision_message["image_url"]["format"] = mime_type # Append vision message messages.append( @@ -289,44 +315,61 @@ class LiteLLM: return messages + def _should_use_direct_url(self, image: str) -> bool: + """ + Determine if we should use direct URL passing instead of base64 conversion. + + Args: + image (str): The image source (URL or file path) + + Returns: + bool: True if we should use direct URL, False if we need base64 conversion + """ + # Only use direct URL for HTTP/HTTPS URLs + if not image.startswith(("http://", "https://")): + return False + + # Check for local/custom models that might not support direct URLs + model_lower = self.model_name.lower() + local_indicators = ["localhost", "127.0.0.1", "local", "custom", "ollama", "llama-cpp"] + + is_local = any(indicator in model_lower for indicator in local_indicators) or \ + (self.base_url is not None and any(indicator in self.base_url.lower() for indicator in local_indicators)) + + if is_local: + return False + + # Use LiteLLM's supports_vision to check if model supports vision and direct URLs + try: + return supports_vision(model=self.model_name) + except Exception: + return False + def vision_processing( self, task: str, image: str, messages: Optional[list] = None ): """ Process the image for the given task. Handles different image formats and model requirements. + + This method now intelligently chooses between: + 1. Direct URL passing (when model supports it and image is a URL) + 2. Base64 conversion (for local files or unsupported models) + + This approach reduces server load and improves performance by avoiding + unnecessary image downloads and base64 conversions when possible. """ - # # # Handle Anthropic models separately - # # if "anthropic" in self.model_name.lower() or "claude" in self.model_name.lower(): - # # messages = self.anthropic_vision_processing(task, image, messages) - # # return messages - - # # Get base64 encoded image with proper format - # image_url = get_image_base64(image) - - # # Prepare vision message - # vision_message = { - # "type": "image_url", - # "image_url": {"url": image_url}, - # } - - # # Add format for specific models - # extension = Path(image).suffix.lower() - # mime_type = f"image/{extension[1:]}" if extension else "image/jpeg" - # vision_message["image_url"]["format"] = mime_type - - # # Append vision message - # messages.append( - # { - # "role": "user", - # "content": [ - # {"type": "text", "text": task}, - # vision_message, - # ], - # } - # ) - - # return messages + logger.info(f"Processing image for model: {self.model_name}") + + # Log whether we're using direct URL or base64 conversion + if self._should_use_direct_url(image): + logger.info(f"Using direct URL passing for image: {image[:100]}...") + else: + if image.startswith(("http://", "https://")): + logger.info("Converting URL image to base64 (model doesn't support direct URLs)") + else: + logger.info("Converting local file to base64") + if ( "anthropic" in self.model_name.lower() or "claude" in self.model_name.lower() @@ -370,7 +413,16 @@ class LiteLLM: def check_if_model_supports_vision(self, img: str = None): """ - Check if the model supports vision. + Check if the model supports vision capabilities. + + This method uses LiteLLM's built-in supports_vision function to verify + that the model can handle image inputs before processing. + + Args: + img (str, optional): Image path/URL to validate against model capabilities + + Raises: + ValueError: If the model doesn't support vision and an image is provided """ if img is not None: out = supports_vision(model=self.model_name) diff --git a/tests/utils/test_litellm_wrapper.py b/tests/utils/test_litellm_wrapper.py index 02e79c9f..3a657bae 100644 --- a/tests/utils/test_litellm_wrapper.py +++ b/tests/utils/test_litellm_wrapper.py @@ -201,6 +201,119 @@ def run_test_suite(): except Exception as e: log_test_result("Batched Run", False, str(e)) + # Test 8: Vision Support Check + try: + logger.info("Testing vision support check") + llm = LiteLLM(model_name="gpt-4o") + # This should not raise an error for vision-capable models + llm.check_if_model_supports_vision(img="test.jpg") + log_test_result("Vision Support Check", True) + except Exception as e: + log_test_result("Vision Support Check", False, str(e)) + + # Test 9: Direct URL Processing + try: + logger.info("Testing direct URL processing") + llm = LiteLLM(model_name="gpt-4o") + test_url = "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true" + should_use_direct = llm._should_use_direct_url(test_url) + assert isinstance(should_use_direct, bool) + log_test_result("Direct URL Processing", True) + except Exception as e: + log_test_result("Direct URL Processing", False, str(e)) + + # Test 10: Message Preparation with Image + try: + logger.info("Testing message preparation with image") + llm = LiteLLM(model_name="gpt-4o") + # Mock image URL to test message structure + test_img = "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true" + messages = llm._prepare_messages("Describe this image", img=test_img) + assert isinstance(messages, list) + assert len(messages) >= 1 + # Check if image content is properly structured + user_message = next((msg for msg in messages if msg["role"] == "user"), None) + assert user_message is not None + log_test_result("Message Preparation with Image", True) + except Exception as e: + log_test_result("Message Preparation with Image", False, str(e)) + + # Test 11: Vision Processing Methods + try: + logger.info("Testing vision processing methods") + llm = LiteLLM(model_name="gpt-4o") + messages = [] + + # Test OpenAI vision processing + processed_messages = llm.openai_vision_processing( + "Describe this image", + "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true", + messages.copy() + ) + assert isinstance(processed_messages, list) + assert len(processed_messages) > 0 + + # Test Anthropic vision processing + llm_anthropic = LiteLLM(model_name="claude-3-5-sonnet-20241022") + processed_messages_anthropic = llm_anthropic.anthropic_vision_processing( + "Describe this image", + "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true", + messages.copy() + ) + assert isinstance(processed_messages_anthropic, list) + assert len(processed_messages_anthropic) > 0 + + log_test_result("Vision Processing Methods", True) + except Exception as e: + log_test_result("Vision Processing Methods", False, str(e)) + + # Test 12: Local vs URL Detection + try: + logger.info("Testing local vs URL detection") + llm = LiteLLM(model_name="gpt-4o") + + # Test URL detection + url_test = "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true" + is_url_direct = llm._should_use_direct_url(url_test) + + # Test local file detection + local_test = "/path/to/local/image.jpg" + is_local_direct = llm._should_use_direct_url(local_test) + + # URLs should potentially use direct, local files should not + assert isinstance(is_url_direct, bool) + assert isinstance(is_local_direct, bool) + assert is_local_direct == False # Local files should never use direct URL + + log_test_result("Local vs URL Detection", True) + except Exception as e: + log_test_result("Local vs URL Detection", False, str(e)) + + # Test 13: Vision Message Structure + try: + logger.info("Testing vision message structure") + llm = LiteLLM(model_name="gpt-4o") + messages = [] + + # Test message structure for image input + result = llm.vision_processing( + task="What do you see?", + image="https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true", + messages=messages + ) + + assert isinstance(result, list) + assert len(result) > 0 + + # Verify the message contains both text and image components + user_msg = result[-1] # Last message should be user message + assert user_msg["role"] == "user" + assert "content" in user_msg + + log_test_result("Vision Message Structure", True) + except Exception as e: + log_test_result("Vision Message Structure", False, str(e)) + # Generate test report success_rate = (passed_tests / total_tests) * 100 logger.info("\n=== Test Suite Report ===")