[FEAT][Gigabind]

pull/317/head
Kye 1 year ago
parent 356d1dcce1
commit 3a9d428ebe

@ -23,6 +23,7 @@ weaviate-client==3.25.3
huggingface-hub==0.16.4
google-generativeai==0.3.1
sentencepiece==0.1.98
requests_mock
PyPDF2==3.0.1
accelerate==0.22.0
vllm

@ -10,7 +10,6 @@ from swarms.models.openai_models import (
) # noqa: E402
# from swarms.models.vllm import vLLM # noqa: E402
# from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
@ -32,7 +31,7 @@ from swarms.models.layoutlm_document_qa import (
from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402
from swarms.models.openai_tts import OpenAITTS # noqa: E402
from swarms.models.gemini import Gemini # noqa: E402
from swarms.models.gigabind import Gigabind # noqa: E402
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402
@ -65,4 +64,6 @@ __all__ = [
# "vLLM",
"OpenAITTS",
"Gemini",
"Gigabind"
]

@ -1,8 +1,6 @@
import asyncio
import base64
import concurrent.futures
import logging
import os
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
@ -69,7 +67,7 @@ class BaseMultiModalModel:
def __init__(
self,
model_name: Optional[str],
model_name: Optional[str] = None,
temperature: Optional[int] = 0.5,
max_tokens: Optional[int] = 500,
max_workers: Optional[int] = 10,
@ -100,7 +98,7 @@ class BaseMultiModalModel:
@abstractmethod
def run(
self, task: Optional[str], img: Optional[str], *args, **kwargs
self, task: Optional[str] = None, img: Optional[str] = None, *args, **kwargs
):
"""Run the model"""
pass

@ -0,0 +1,107 @@
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
class Gigabind:
"""Gigabind API.
Args:
host (str, optional): host. Defaults to None.
proxy_url (str, optional): proxy_url. Defaults to None.
port (int, optional): port. Defaults to 8000.
endpoint (str, optional): endpoint. Defaults to "embeddings".
Examples:
>>> from swarms.models.gigabind import Gigabind
>>> api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
>>> response = api.run(text="Hello, world!", vision="image.jpg")
>>> print(response)
"""
def __init__(
self,
host: str = None,
proxy_url: str = None,
port: int = 8000,
endpoint: str = "embeddings",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.host = host
self.proxy_url = proxy_url
self.port = port
self.endpoint = endpoint
# Set the URL to the API
if self.proxy_url is not None:
self.url = f"{self.proxy_url}"
else:
self.url = f"http://{host}:{port}/{endpoint}"
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def run(
self,
text: str = None,
vision: str = None,
audio: str = None,
*args,
**kwargs,
):
"""Run the Gigabind API.
Args:
text (str, optional): text. Defaults to None.
vision (str, optional): images. Defaults to None.
audio (str, optional): audio file paths. Defaults to None.
Raises:
ValueError: At least one of text, vision or audio must be provided
Returns:
embeddings: embeddings
"""
try:
# Prepare the data to send to the API
data = {}
if text is not None:
data["text"] = text
if vision is not None:
data["vision"] = vision
if audio is not None:
data["audio"] = audio
else:
raise ValueError(
"At least one of text, vision or audio must be"
" provided"
)
# Send a POST request to the API and return the response
response = requests.post(
self.url, json=data, *args, **kwargs
)
return response.json()
except Exception as error:
print(f"Gigabind API error: {error}")
return None
def generate_summary(self, text: str = None, *args, **kwargs):
# Prepare the data to send to the API
data = {}
if text is not None:
data["text"] = text
else:
raise ValueError(
"At least one of text, vision or audio must be"
" provided"
)
# Send a POST request to the API and return the response
response = requests.post(self.url, json=data, *args, **kwargs)
return response.json()
# api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
# response = api.run(text="Hello, world!", vision="image.jpg")
# print(response)

@ -0,0 +1,183 @@
import pytest
import requests
from swarms.models.gigabind import Gigabind
try:
import requests_mock
except ImportError:
requests_mock = None
@pytest.fixture
def api():
return Gigabind(
host="localhost", port=8000, endpoint="embeddings"
)
@pytest.fixture
def mock(requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", json={"result": "success"}
)
return requests_mock
def test_run_with_text(api, mock):
response = api.run(text="Hello, world!")
assert response == {"result": "success"}
def test_run_with_vision(api, mock):
response = api.run(vision="image.jpg")
assert response == {"result": "success"}
def test_run_with_audio(api, mock):
response = api.run(audio="audio.mp3")
assert response == {"result": "success"}
def test_run_with_all(api, mock):
response = api.run(
text="Hello, world!", vision="image.jpg", audio="audio.mp3"
)
assert response == {"result": "success"}
def test_run_with_none(api):
with pytest.raises(ValueError):
api.run()
def test_generate_summary(api, mock):
response = api.generate_summary(text="Hello, world!")
assert response == {"result": "success"}
def test_generate_summary_with_none(api):
with pytest.raises(ValueError):
api.generate_summary()
def test_retry_on_failure(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 200, "json": {"result": "success"}},
],
)
response = api.run(text="Hello, world!")
assert response == {"result": "success"}
def test_retry_exhausted(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
],
)
response = api.run(text="Hello, world!")
assert response is None
def test_proxy_url(api):
api.proxy_url = "http://proxy:8080"
assert api.url == "http://proxy:8080"
def test_invalid_response(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", text="not json"
)
response = api.run(text="Hello, world!")
assert response is None
def test_connection_error(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
exc=requests.exceptions.ConnectTimeout,
)
response = api.run(text="Hello, world!")
assert response is None
def test_http_error(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", status_code=500
)
response = api.run(text="Hello, world!")
assert response is None
def test_url_construction(api):
assert api.url == "http://localhost:8000/embeddings"
def test_url_construction_with_proxy(api):
api.proxy_url = "http://proxy:8080"
assert api.url == "http://proxy:8080"
def test_run_with_large_text(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions
response = api.run(text=large_text)
assert response == {"result": "success"}
def test_run_with_large_vision(api, mock):
large_vision = "image.jpg" * 10000 # 10,000 repetitions
response = api.run(vision=large_vision)
assert response == {"result": "success"}
def test_run_with_large_audio(api, mock):
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(audio=large_audio)
assert response == {"result": "success"}
def test_run_with_large_all(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions
large_vision = "image.jpg" * 10000 # 10,000 repetitions
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(
text=large_text, vision=large_vision, audio=large_audio
)
assert response == {"result": "success"}
def test_run_with_timeout(api, mock):
response = api.run(text="Hello, world!", timeout=0.001)
assert response is None
def test_run_with_invalid_host(api):
api.host = "invalid"
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_port(api):
api.port = 99999
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_endpoint(api):
api.endpoint = "invalid"
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_proxy_url(api):
api.proxy_url = "invalid"
response = api.run(text="Hello, world!")
assert response is None
Loading…
Cancel
Save