[FEAT][Gigabind]

pull/317/head
Kye 1 year ago
parent 356d1dcce1
commit 3a9d428ebe

@ -23,6 +23,7 @@ weaviate-client==3.25.3
huggingface-hub==0.16.4 huggingface-hub==0.16.4
google-generativeai==0.3.1 google-generativeai==0.3.1
sentencepiece==0.1.98 sentencepiece==0.1.98
requests_mock
PyPDF2==3.0.1 PyPDF2==3.0.1
accelerate==0.22.0 accelerate==0.22.0
vllm vllm

@ -10,7 +10,6 @@ from swarms.models.openai_models import (
) # noqa: E402 ) # noqa: E402
# from swarms.models.vllm import vLLM # noqa: E402 # from swarms.models.vllm import vLLM # noqa: E402
# from swarms.models.zephyr import Zephyr # noqa: E402 # from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402 from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
@ -32,7 +31,7 @@ from swarms.models.layoutlm_document_qa import (
from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402 from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402
from swarms.models.openai_tts import OpenAITTS # noqa: E402 from swarms.models.openai_tts import OpenAITTS # noqa: E402
from swarms.models.gemini import Gemini # noqa: E402 from swarms.models.gemini import Gemini # noqa: E402
from swarms.models.gigabind import Gigabind # noqa: E402
# from swarms.models.gpt4v import GPT4Vision # from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402 # from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402
@ -65,4 +64,6 @@ __all__ = [
# "vLLM", # "vLLM",
"OpenAITTS", "OpenAITTS",
"Gemini", "Gemini",
"Gigabind"
] ]

@ -1,8 +1,6 @@
import asyncio import asyncio
import base64 import base64
import concurrent.futures import concurrent.futures
import logging
import os
import time import time
from abc import abstractmethod from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -69,7 +67,7 @@ class BaseMultiModalModel:
def __init__( def __init__(
self, self,
model_name: Optional[str], model_name: Optional[str] = None,
temperature: Optional[int] = 0.5, temperature: Optional[int] = 0.5,
max_tokens: Optional[int] = 500, max_tokens: Optional[int] = 500,
max_workers: Optional[int] = 10, max_workers: Optional[int] = 10,
@ -100,7 +98,7 @@ class BaseMultiModalModel:
@abstractmethod @abstractmethod
def run( def run(
self, task: Optional[str], img: Optional[str], *args, **kwargs self, task: Optional[str] = None, img: Optional[str] = None, *args, **kwargs
): ):
"""Run the model""" """Run the model"""
pass pass

@ -0,0 +1,107 @@
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
class Gigabind:
"""Gigabind API.
Args:
host (str, optional): host. Defaults to None.
proxy_url (str, optional): proxy_url. Defaults to None.
port (int, optional): port. Defaults to 8000.
endpoint (str, optional): endpoint. Defaults to "embeddings".
Examples:
>>> from swarms.models.gigabind import Gigabind
>>> api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
>>> response = api.run(text="Hello, world!", vision="image.jpg")
>>> print(response)
"""
def __init__(
self,
host: str = None,
proxy_url: str = None,
port: int = 8000,
endpoint: str = "embeddings",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.host = host
self.proxy_url = proxy_url
self.port = port
self.endpoint = endpoint
# Set the URL to the API
if self.proxy_url is not None:
self.url = f"{self.proxy_url}"
else:
self.url = f"http://{host}:{port}/{endpoint}"
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def run(
self,
text: str = None,
vision: str = None,
audio: str = None,
*args,
**kwargs,
):
"""Run the Gigabind API.
Args:
text (str, optional): text. Defaults to None.
vision (str, optional): images. Defaults to None.
audio (str, optional): audio file paths. Defaults to None.
Raises:
ValueError: At least one of text, vision or audio must be provided
Returns:
embeddings: embeddings
"""
try:
# Prepare the data to send to the API
data = {}
if text is not None:
data["text"] = text
if vision is not None:
data["vision"] = vision
if audio is not None:
data["audio"] = audio
else:
raise ValueError(
"At least one of text, vision or audio must be"
" provided"
)
# Send a POST request to the API and return the response
response = requests.post(
self.url, json=data, *args, **kwargs
)
return response.json()
except Exception as error:
print(f"Gigabind API error: {error}")
return None
def generate_summary(self, text: str = None, *args, **kwargs):
# Prepare the data to send to the API
data = {}
if text is not None:
data["text"] = text
else:
raise ValueError(
"At least one of text, vision or audio must be"
" provided"
)
# Send a POST request to the API and return the response
response = requests.post(self.url, json=data, *args, **kwargs)
return response.json()
# api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
# response = api.run(text="Hello, world!", vision="image.jpg")
# print(response)

@ -0,0 +1,183 @@
import pytest
import requests
from swarms.models.gigabind import Gigabind
try:
import requests_mock
except ImportError:
requests_mock = None
@pytest.fixture
def api():
return Gigabind(
host="localhost", port=8000, endpoint="embeddings"
)
@pytest.fixture
def mock(requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", json={"result": "success"}
)
return requests_mock
def test_run_with_text(api, mock):
response = api.run(text="Hello, world!")
assert response == {"result": "success"}
def test_run_with_vision(api, mock):
response = api.run(vision="image.jpg")
assert response == {"result": "success"}
def test_run_with_audio(api, mock):
response = api.run(audio="audio.mp3")
assert response == {"result": "success"}
def test_run_with_all(api, mock):
response = api.run(
text="Hello, world!", vision="image.jpg", audio="audio.mp3"
)
assert response == {"result": "success"}
def test_run_with_none(api):
with pytest.raises(ValueError):
api.run()
def test_generate_summary(api, mock):
response = api.generate_summary(text="Hello, world!")
assert response == {"result": "success"}
def test_generate_summary_with_none(api):
with pytest.raises(ValueError):
api.generate_summary()
def test_retry_on_failure(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 200, "json": {"result": "success"}},
],
)
response = api.run(text="Hello, world!")
assert response == {"result": "success"}
def test_retry_exhausted(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
[
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
{"status_code": 500, "json": {}},
],
)
response = api.run(text="Hello, world!")
assert response is None
def test_proxy_url(api):
api.proxy_url = "http://proxy:8080"
assert api.url == "http://proxy:8080"
def test_invalid_response(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", text="not json"
)
response = api.run(text="Hello, world!")
assert response is None
def test_connection_error(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings",
exc=requests.exceptions.ConnectTimeout,
)
response = api.run(text="Hello, world!")
assert response is None
def test_http_error(api, requests_mock):
requests_mock.post(
"http://localhost:8000/embeddings", status_code=500
)
response = api.run(text="Hello, world!")
assert response is None
def test_url_construction(api):
assert api.url == "http://localhost:8000/embeddings"
def test_url_construction_with_proxy(api):
api.proxy_url = "http://proxy:8080"
assert api.url == "http://proxy:8080"
def test_run_with_large_text(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions
response = api.run(text=large_text)
assert response == {"result": "success"}
def test_run_with_large_vision(api, mock):
large_vision = "image.jpg" * 10000 # 10,000 repetitions
response = api.run(vision=large_vision)
assert response == {"result": "success"}
def test_run_with_large_audio(api, mock):
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(audio=large_audio)
assert response == {"result": "success"}
def test_run_with_large_all(api, mock):
large_text = "Hello, world! " * 10000 # 10,000 repetitions
large_vision = "image.jpg" * 10000 # 10,000 repetitions
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
response = api.run(
text=large_text, vision=large_vision, audio=large_audio
)
assert response == {"result": "success"}
def test_run_with_timeout(api, mock):
response = api.run(text="Hello, world!", timeout=0.001)
assert response is None
def test_run_with_invalid_host(api):
api.host = "invalid"
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_port(api):
api.port = 99999
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_endpoint(api):
api.endpoint = "invalid"
response = api.run(text="Hello, world!")
assert response is None
def test_run_with_invalid_proxy_url(api):
api.proxy_url = "invalid"
response = api.run(text="Hello, world!")
assert response is None
Loading…
Cancel
Save