parent
356d1dcce1
commit
3a9d428ebe
@ -0,0 +1,107 @@
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
|
||||
|
||||
class Gigabind:
|
||||
"""Gigabind API.
|
||||
|
||||
Args:
|
||||
host (str, optional): host. Defaults to None.
|
||||
proxy_url (str, optional): proxy_url. Defaults to None.
|
||||
port (int, optional): port. Defaults to 8000.
|
||||
endpoint (str, optional): endpoint. Defaults to "embeddings".
|
||||
|
||||
Examples:
|
||||
>>> from swarms.models.gigabind import Gigabind
|
||||
>>> api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
|
||||
>>> response = api.run(text="Hello, world!", vision="image.jpg")
|
||||
>>> print(response)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = None,
|
||||
proxy_url: str = None,
|
||||
port: int = 8000,
|
||||
endpoint: str = "embeddings",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.host = host
|
||||
self.proxy_url = proxy_url
|
||||
self.port = port
|
||||
self.endpoint = endpoint
|
||||
|
||||
# Set the URL to the API
|
||||
if self.proxy_url is not None:
|
||||
self.url = f"{self.proxy_url}"
|
||||
else:
|
||||
self.url = f"http://{host}:{port}/{endpoint}"
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
||||
def run(
|
||||
self,
|
||||
text: str = None,
|
||||
vision: str = None,
|
||||
audio: str = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run the Gigabind API.
|
||||
|
||||
Args:
|
||||
text (str, optional): text. Defaults to None.
|
||||
vision (str, optional): images. Defaults to None.
|
||||
audio (str, optional): audio file paths. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: At least one of text, vision or audio must be provided
|
||||
|
||||
Returns:
|
||||
embeddings: embeddings
|
||||
"""
|
||||
try:
|
||||
# Prepare the data to send to the API
|
||||
data = {}
|
||||
if text is not None:
|
||||
data["text"] = text
|
||||
if vision is not None:
|
||||
data["vision"] = vision
|
||||
if audio is not None:
|
||||
data["audio"] = audio
|
||||
else:
|
||||
raise ValueError(
|
||||
"At least one of text, vision or audio must be"
|
||||
" provided"
|
||||
)
|
||||
|
||||
# Send a POST request to the API and return the response
|
||||
response = requests.post(
|
||||
self.url, json=data, *args, **kwargs
|
||||
)
|
||||
return response.json()
|
||||
except Exception as error:
|
||||
print(f"Gigabind API error: {error}")
|
||||
return None
|
||||
|
||||
def generate_summary(self, text: str = None, *args, **kwargs):
|
||||
# Prepare the data to send to the API
|
||||
data = {}
|
||||
if text is not None:
|
||||
data["text"] = text
|
||||
else:
|
||||
raise ValueError(
|
||||
"At least one of text, vision or audio must be"
|
||||
" provided"
|
||||
)
|
||||
|
||||
# Send a POST request to the API and return the response
|
||||
response = requests.post(self.url, json=data, *args, **kwargs)
|
||||
return response.json()
|
||||
|
||||
|
||||
# api = Gigabind(host="localhost", port=8000, endpoint="embeddings")
|
||||
# response = api.run(text="Hello, world!", vision="image.jpg")
|
||||
# print(response)
|
@ -0,0 +1,183 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from swarms.models.gigabind import Gigabind
|
||||
|
||||
try:
|
||||
import requests_mock
|
||||
except ImportError:
|
||||
requests_mock = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api():
|
||||
return Gigabind(
|
||||
host="localhost", port=8000, endpoint="embeddings"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock(requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings", json={"result": "success"}
|
||||
)
|
||||
return requests_mock
|
||||
|
||||
|
||||
def test_run_with_text(api, mock):
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_vision(api, mock):
|
||||
response = api.run(vision="image.jpg")
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_audio(api, mock):
|
||||
response = api.run(audio="audio.mp3")
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_all(api, mock):
|
||||
response = api.run(
|
||||
text="Hello, world!", vision="image.jpg", audio="audio.mp3"
|
||||
)
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_none(api):
|
||||
with pytest.raises(ValueError):
|
||||
api.run()
|
||||
|
||||
|
||||
def test_generate_summary(api, mock):
|
||||
response = api.generate_summary(text="Hello, world!")
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_generate_summary_with_none(api):
|
||||
with pytest.raises(ValueError):
|
||||
api.generate_summary()
|
||||
|
||||
|
||||
def test_retry_on_failure(api, requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings",
|
||||
[
|
||||
{"status_code": 500, "json": {}},
|
||||
{"status_code": 500, "json": {}},
|
||||
{"status_code": 200, "json": {"result": "success"}},
|
||||
],
|
||||
)
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_retry_exhausted(api, requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings",
|
||||
[
|
||||
{"status_code": 500, "json": {}},
|
||||
{"status_code": 500, "json": {}},
|
||||
{"status_code": 500, "json": {}},
|
||||
],
|
||||
)
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_proxy_url(api):
|
||||
api.proxy_url = "http://proxy:8080"
|
||||
assert api.url == "http://proxy:8080"
|
||||
|
||||
|
||||
def test_invalid_response(api, requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings", text="not json"
|
||||
)
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_connection_error(api, requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings",
|
||||
exc=requests.exceptions.ConnectTimeout,
|
||||
)
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_http_error(api, requests_mock):
|
||||
requests_mock.post(
|
||||
"http://localhost:8000/embeddings", status_code=500
|
||||
)
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_url_construction(api):
|
||||
assert api.url == "http://localhost:8000/embeddings"
|
||||
|
||||
|
||||
def test_url_construction_with_proxy(api):
|
||||
api.proxy_url = "http://proxy:8080"
|
||||
assert api.url == "http://proxy:8080"
|
||||
|
||||
|
||||
def test_run_with_large_text(api, mock):
|
||||
large_text = "Hello, world! " * 10000 # 10,000 repetitions
|
||||
response = api.run(text=large_text)
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_large_vision(api, mock):
|
||||
large_vision = "image.jpg" * 10000 # 10,000 repetitions
|
||||
response = api.run(vision=large_vision)
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_large_audio(api, mock):
|
||||
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
|
||||
response = api.run(audio=large_audio)
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_large_all(api, mock):
|
||||
large_text = "Hello, world! " * 10000 # 10,000 repetitions
|
||||
large_vision = "image.jpg" * 10000 # 10,000 repetitions
|
||||
large_audio = "audio.mp3" * 10000 # 10,000 repetitions
|
||||
response = api.run(
|
||||
text=large_text, vision=large_vision, audio=large_audio
|
||||
)
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
def test_run_with_timeout(api, mock):
|
||||
response = api.run(text="Hello, world!", timeout=0.001)
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_run_with_invalid_host(api):
|
||||
api.host = "invalid"
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_run_with_invalid_port(api):
|
||||
api.port = 99999
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_run_with_invalid_endpoint(api):
|
||||
api.endpoint = "invalid"
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_run_with_invalid_proxy_url(api):
|
||||
api.proxy_url = "invalid"
|
||||
response = api.run(text="Hello, world!")
|
||||
assert response is None
|
Loading…
Reference in new issue