From 3a9d428ebe9d0d5ba192964c84443e7095155398 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sun, 17 Dec 2023 01:40:50 -0500
Subject: [PATCH] [FEAT][Gigabind]

---
 requirements.txt                       |   1 +
 swarms/models/__init__.py              |   5 +-
 swarms/models/base_multimodal_model.py |   6 +-
 swarms/models/gigabind.py              | 114 +++++++++++++++
 tests/models/test_gigabind.py          | 185 +++++++++++++++++++++++++
 5 files changed, 305 insertions(+), 6 deletions(-)
 create mode 100644 swarms/models/gigabind.py
 create mode 100644 tests/models/test_gigabind.py

diff --git a/requirements.txt b/requirements.txt
index 1b3222fc..ed4c9069 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,6 +23,7 @@ weaviate-client==3.25.3
 huggingface-hub==0.16.4
 google-generativeai==0.3.1
 sentencepiece==0.1.98
+requests_mock
 PyPDF2==3.0.1
 accelerate==0.22.0
 vllm
diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index 288d4d95..94fa6c1e 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -10,7 +10,6 @@ from swarms.models.openai_models import (
 )  # noqa: E402
 
 # from swarms.models.vllm import vLLM  # noqa: E402
-
 # from swarms.models.zephyr import Zephyr  # noqa: E402
 from swarms.models.biogpt import BioGPT  # noqa: E402
 from swarms.models.huggingface import HuggingfaceLLM  # noqa: E402
@@ -32,7 +31,7 @@ from swarms.models.layoutlm_document_qa import (
 from swarms.models.gpt4_vision_api import GPT4VisionAPI  # noqa: E402
 from swarms.models.openai_tts import OpenAITTS  # noqa: E402
 from swarms.models.gemini import Gemini  # noqa: E402
-
+from swarms.models.gigabind import Gigabind  # noqa: E402
 # from swarms.models.gpt4v import GPT4Vision
 # from swarms.models.dalle3 import Dalle3
 # from swarms.models.distilled_whisperx import DistilWhisperModel  # noqa: E402
@@ -65,4 +64,6 @@ __all__ = [
     # "vLLM",
     "OpenAITTS",
     "Gemini",
+    "Gigabind"
 ]
+
diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py
index c31931f2..1d8ac742 100644
--- a/swarms/models/base_multimodal_model.py
+++ b/swarms/models/base_multimodal_model.py
@@ -1,8 +1,6 @@
 import asyncio
 import base64
 import concurrent.futures
-import logging
-import os
 import time
 from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
@@ -69,7 +67,7 @@ class BaseMultiModalModel:
 
     def __init__(
         self,
-        model_name: Optional[str],
+        model_name: Optional[str] = None,
         temperature: Optional[int] = 0.5,
         max_tokens: Optional[int] = 500,
         max_workers: Optional[int] = 10,
@@ -100,7 +98,7 @@ class BaseMultiModalModel:
 
     @abstractmethod
     def run(
-        self, task: Optional[str], img: Optional[str], *args, **kwargs
+        self, task: Optional[str] = None, img: Optional[str] = None, *args, **kwargs
     ):
         """Run the model"""
         pass
diff --git a/swarms/models/gigabind.py b/swarms/models/gigabind.py
new file mode 100644
index 00000000..fa79d828
--- /dev/null
+++ b/swarms/models/gigabind.py
@@ -0,0 +1,114 @@
+import requests
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+
+class Gigabind:
+    """Gigabind API.
+
+    The request URL is computed from the current attribute values on every
+    access, so changing ``host``, ``port``, ``endpoint`` or ``proxy_url``
+    after construction takes effect on the next request.
+
+    Args:
+        host (str, optional): host. Defaults to None.
+        proxy_url (str, optional): proxy_url. Defaults to None.
+        port (int, optional): port. Defaults to 8000.
+        endpoint (str, optional): endpoint. Defaults to "embeddings".
+
+    Examples:
+        >>> from swarms.models.gigabind import Gigabind
+        >>> api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
+        >>> response = api.run(text="Hello, world!", vision="image.jpg")
+        >>> print(response)
+    """
+
+    def __init__(
+        self,
+        host: str = None,
+        proxy_url: str = None,
+        port: int = 8000,
+        endpoint: str = "embeddings",
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.host = host
+        self.proxy_url = proxy_url
+        self.port = port
+        self.endpoint = endpoint
+
+    @property
+    def url(self) -> str:
+        """URL requests are sent to; ``proxy_url`` wins when it is set."""
+        if self.proxy_url is not None:
+            return f"{self.proxy_url}"
+        return f"http://{self.host}:{self.port}/{self.endpoint}"
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
+    def _post(self, data: dict, *args, **kwargs):
+        """POST ``data`` as JSON and return the decoded JSON response.
+
+        Connection errors, 4xx/5xx statuses and non-JSON bodies raise, so
+        tenacity retries them (3 attempts, 2 seconds apart) and then
+        re-raises the last error.
+        """
+        response = requests.post(self.url, json=data, *args, **kwargs)
+        response.raise_for_status()
+        return response.json()
+
+    def run(
+        self,
+        text: str = None,
+        vision: str = None,
+        audio: str = None,
+        *args,
+        **kwargs,
+    ):
+        """Run the Gigabind API.
+
+        Args:
+            text (str, optional): text. Defaults to None.
+            vision (str, optional): images. Defaults to None.
+            audio (str, optional): audio file paths. Defaults to None.
+
+        Raises:
+            ValueError: At least one of text, vision or audio must be provided
+
+        Returns:
+            embeddings: decoded JSON response, or None when the request
+                still fails after the retries
+        """
+        # Validate before the retry/except path so a call with no inputs
+        # raises ValueError instead of being retried or swallowed.
+        inputs = {"text": text, "vision": vision, "audio": audio}
+        data = {
+            key: value
+            for key, value in inputs.items()
+            if value is not None
+        }
+        if not data:
+            raise ValueError(
+                "At least one of text, vision or audio must be"
+                " provided"
+            )
+
+        try:
+            return self._post(data, *args, **kwargs)
+        except Exception as error:
+            print(f"Gigabind API error: {error}")
+            return None
+
+    def generate_summary(self, text: str = None, *args, **kwargs):
+        """POST ``text`` to the API and return the decoded JSON response.
+
+        Raises:
+            ValueError: text was not provided
+        """
+        if text is None:
+            raise ValueError("text must be provided")
+
+        # Send a POST request to the API and return the response
+        response = requests.post(
+            self.url, json={"text": text}, *args, **kwargs
+        )
+        return response.json()
diff --git a/tests/models/test_gigabind.py b/tests/models/test_gigabind.py
new file mode 100644
index 00000000..3aae0739
--- /dev/null
+++ b/tests/models/test_gigabind.py
@@ -0,0 +1,185 @@
+import pytest
+import requests
+
+from swarms.models.gigabind import Gigabind
+
+# The ``requests_mock`` fixture is provided by the requests-mock pytest
+# plugin; skip this module instead of erroring when it is not installed.
+pytest.importorskip("requests_mock")
+
+
+@pytest.fixture
+def api():
+    return Gigabind(
+        host="localhost", port=8000, endpoint="embeddings"
+    )
+
+
+@pytest.fixture
+def mock(requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings", json={"result": "success"}
+    )
+    return requests_mock
+
+
+def test_run_with_text(api, mock):
+    response = api.run(text="Hello, world!")
+    assert response == {"result": "success"}
+
+
+def test_run_with_vision(api, mock):
+    response = api.run(vision="image.jpg")
+    assert response == {"result": "success"}
+
+
+def test_run_with_audio(api, mock):
+    response = api.run(audio="audio.mp3")
+    assert response == {"result": "success"}
+
+
+def test_run_with_all(api, mock):
+    response = api.run(
+        text="Hello, world!", vision="image.jpg", audio="audio.mp3"
+    )
+    assert response == {"result": "success"}
+
+
+def test_run_with_none(api):
+    with pytest.raises(ValueError):
+        api.run()
+
+
+def test_generate_summary(api, mock):
+    response = api.generate_summary(text="Hello, world!")
+    assert response == {"result": "success"}
+
+
+def test_generate_summary_with_none(api):
+    with pytest.raises(ValueError):
+        api.generate_summary()
+
+
+def test_retry_on_failure(api, requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings",
+        [
+            {"status_code": 500, "json": {}},
+            {"status_code": 500, "json": {}},
+            {"status_code": 200, "json": {"result": "success"}},
+        ],
+    )
+    response = api.run(text="Hello, world!")
+    assert response == {"result": "success"}
+
+
+def test_retry_exhausted(api, requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings",
+        [
+            {"status_code": 500, "json": {}},
+            {"status_code": 500, "json": {}},
+            {"status_code": 500, "json": {}},
+        ],
+    )
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_proxy_url(api):
+    api.proxy_url = "http://proxy:8080"
+    assert api.url == "http://proxy:8080"
+
+
+def test_invalid_response(api, requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings", text="not json"
+    )
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_connection_error(api, requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings",
+        exc=requests.exceptions.ConnectTimeout,
+    )
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_http_error(api, requests_mock):
+    requests_mock.post(
+        "http://localhost:8000/embeddings", status_code=500
+    )
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_url_construction(api):
+    assert api.url == "http://localhost:8000/embeddings"
+
+
+def test_url_construction_with_proxy():
+    api = Gigabind(proxy_url="http://proxy:8080")
+    assert api.url == "http://proxy:8080"
+
+
+def test_run_with_large_text(api, mock):
+    large_text = "Hello, world! " * 10000  # 10,000 repetitions
+    response = api.run(text=large_text)
+    assert response == {"result": "success"}
+
+
+def test_run_with_large_vision(api, mock):
+    large_vision = "image.jpg" * 10000  # 10,000 repetitions
+    response = api.run(vision=large_vision)
+    assert response == {"result": "success"}
+
+
+def test_run_with_large_audio(api, mock):
+    large_audio = "audio.mp3" * 10000  # 10,000 repetitions
+    response = api.run(audio=large_audio)
+    assert response == {"result": "success"}
+
+
+def test_run_with_large_all(api, mock):
+    large_text = "Hello, world! " * 10000  # 10,000 repetitions
+    large_vision = "image.jpg" * 10000  # 10,000 repetitions
+    large_audio = "audio.mp3" * 10000  # 10,000 repetitions
+    response = api.run(
+        text=large_text, vision=large_vision, audio=large_audio
+    )
+    assert response == {"result": "success"}
+
+
+def test_run_with_timeout(api, mock):
+    # requests_mock never times out, so check that the keyword reaches
+    # requests.post instead of expecting a failure.
+    response = api.run(text="Hello, world!", timeout=0.001)
+    assert response == {"result": "success"}
+    assert mock.last_request.timeout == 0.001
+
+
+def test_run_with_invalid_host(api, requests_mock):
+    api.host = "invalid"
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_run_with_invalid_port(api, requests_mock):
+    api.port = 99999
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_run_with_invalid_endpoint(api, requests_mock):
+    api.endpoint = "invalid"
+    response = api.run(text="Hello, world!")
+    assert response is None
+
+
+def test_run_with_invalid_proxy_url(api, requests_mock):
+    api.proxy_url = "invalid"
+    response = api.run(text="Hello, world!")
+    assert response is None