parent
b1d3aa54a8
commit
4bef09a252
@ -0,0 +1,140 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from swarms.models.base_llm import AbstractLLM
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def together_api_key_env():
|
||||||
|
"""Get the API key from the environment."""
|
||||||
|
return os.getenv("TOGETHER_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
class TogetherModel(AbstractLLM):
|
||||||
|
"""
|
||||||
|
GPT-4 Vision API
|
||||||
|
|
||||||
|
This class is a wrapper for the OpenAI API. It is used to run the GPT-4 Vision model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
together_api_key : str
|
||||||
|
The OpenAI API key. Defaults to the together_api_key environment variable.
|
||||||
|
max_tokens : int
|
||||||
|
The maximum number of tokens to generate. Defaults to 300.
|
||||||
|
|
||||||
|
|
||||||
|
Methods
|
||||||
|
-------
|
||||||
|
encode_image(img: str)
|
||||||
|
Encode image to base64.
|
||||||
|
run(task: str, img: str)
|
||||||
|
Run the model.
|
||||||
|
__call__(task: str, img: str)
|
||||||
|
Run the model.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
---------
|
||||||
|
>>> from swarms.models import GPT4VisionAPI
|
||||||
|
>>> llm = GPT4VisionAPI()
|
||||||
|
>>> task = "What is the color of the object?"
|
||||||
|
>>> img = "https://i.imgur.com/2M2ZGwC.jpeg"
|
||||||
|
>>> llm.run(task, img)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
together_api_key: str = together_api_key_env,
|
||||||
|
model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
|
logging_enabled: bool = False,
|
||||||
|
max_workers: int = 10,
|
||||||
|
max_tokens: str = 300,
|
||||||
|
api_endpoint: str = "https://api.together.xyz",
|
||||||
|
beautify: bool = False,
|
||||||
|
streaming_enabled: Optional[bool] = False,
|
||||||
|
meta_prompt: Optional[bool] = False,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super(TogetherModel).__init__(*args, **kwargs)
|
||||||
|
self.together_api_key = together_api_key
|
||||||
|
self.logging_enabled = logging_enabled
|
||||||
|
self.model_name = model_name
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.api_endpoint = api_endpoint
|
||||||
|
self.beautify = beautify
|
||||||
|
self.streaming_enabled = streaming_enabled
|
||||||
|
self.meta_prompt = meta_prompt
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
if self.logging_enabled:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
# Disable debug logs for requests and urllib3
|
||||||
|
logging.getLogger("requests").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
if self.meta_prompt:
|
||||||
|
self.system_prompt = self.meta_prompt_init()
|
||||||
|
|
||||||
|
# Function to handle vision tasks
|
||||||
|
def run(self, task: str = None, *args, **kwargs):
|
||||||
|
"""Run the model."""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.together_api_key}",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [self.system_prompt],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": task,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
self.api_endpoint,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = response.json()
|
||||||
|
if "choices" in out and out["choices"]:
|
||||||
|
content = (
|
||||||
|
out["choices"][0]
|
||||||
|
.get("message", {})
|
||||||
|
.get("content", None)
|
||||||
|
)
|
||||||
|
if self.streaming_enabled:
|
||||||
|
content = self.stream_response(content)
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
print("No valid response in 'choices'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
print(
|
||||||
|
f"Error with the request: {error}, make sure you"
|
||||||
|
" double check input types and positions"
|
||||||
|
)
|
||||||
|
return None
|
@ -0,0 +1,144 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
from swarms.models.together import TogetherModel
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_api_key(monkeypatch):
|
||||||
|
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_defaults():
|
||||||
|
model = TogetherModel()
|
||||||
|
assert model.together_api_key == "mocked-api-key"
|
||||||
|
assert model.logging_enabled is False
|
||||||
|
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
assert model.max_workers == 10
|
||||||
|
assert model.max_tokens == 300
|
||||||
|
assert model.api_endpoint == "https://api.together.xyz"
|
||||||
|
assert model.beautify is False
|
||||||
|
assert model.streaming_enabled is False
|
||||||
|
assert model.meta_prompt is False
|
||||||
|
assert model.system_prompt is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_custom_params(mock_api_key):
|
||||||
|
model = TogetherModel(
|
||||||
|
together_api_key="custom-api-key",
|
||||||
|
logging_enabled=True,
|
||||||
|
model_name="custom-model",
|
||||||
|
max_workers=5,
|
||||||
|
max_tokens=500,
|
||||||
|
api_endpoint="https://custom-api.together.xyz",
|
||||||
|
beautify=True,
|
||||||
|
streaming_enabled=True,
|
||||||
|
meta_prompt="meta-prompt",
|
||||||
|
system_prompt="system-prompt",
|
||||||
|
)
|
||||||
|
assert model.together_api_key == "custom-api-key"
|
||||||
|
assert model.logging_enabled is True
|
||||||
|
assert model.model_name == "custom-model"
|
||||||
|
assert model.max_workers == 5
|
||||||
|
assert model.max_tokens == 500
|
||||||
|
assert model.api_endpoint == "https://custom-api.together.xyz"
|
||||||
|
assert model.beautify is True
|
||||||
|
assert model.streaming_enabled is True
|
||||||
|
assert model.meta_prompt == "meta-prompt"
|
||||||
|
assert model.system_prompt == "system-prompt"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("swarms.models.together_model.requests.post")
|
||||||
|
def test_run_success(mock_post, mock_api_key):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "Generated response"}}]
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
model = TogetherModel()
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
response = model.run(task)
|
||||||
|
|
||||||
|
assert response == "Generated response"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("swarms.models.together_model.requests.post")
|
||||||
|
def test_run_failure(mock_post, mock_api_key):
|
||||||
|
mock_post.side_effect = requests.exceptions.RequestException(
|
||||||
|
"Request failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = TogetherModel()
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
response = model.run(task)
|
||||||
|
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_logging_enabled(caplog, mock_api_key):
|
||||||
|
model = TogetherModel(logging_enabled=True)
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
|
||||||
|
with caplog.at_level(logging.DEBUG):
|
||||||
|
model.run(task)
|
||||||
|
|
||||||
|
assert "Sending request to" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_input", [None, 123, ["list", "of", "items"]]
|
||||||
|
)
|
||||||
|
def test_invalid_task_input(invalid_input, mock_api_key):
|
||||||
|
model = TogetherModel()
|
||||||
|
response = model.run(invalid_input)
|
||||||
|
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("swarms.models.together_model.requests.post")
|
||||||
|
def test_run_streaming_enabled(mock_post, mock_api_key):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "Generated response"}}]
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
model = TogetherModel(streaming_enabled=True)
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
response = model.run(task)
|
||||||
|
|
||||||
|
assert response == "Generated response"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("swarms.models.together_model.requests.post")
|
||||||
|
def test_run_empty_choices(mock_post, mock_api_key):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"choices": []}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
model = TogetherModel()
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
response = model.run(task)
|
||||||
|
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("swarms.models.together_model.requests.post")
|
||||||
|
def test_run_with_exception(mock_post, mock_api_key):
|
||||||
|
mock_post.side_effect = Exception("Test exception")
|
||||||
|
|
||||||
|
model = TogetherModel()
|
||||||
|
task = "What is the color of the object?"
|
||||||
|
response = model.run(task)
|
||||||
|
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_logging_disabled(monkeypatch):
|
||||||
|
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
||||||
|
model = TogetherModel()
|
||||||
|
assert model.logging_enabled is False
|
||||||
|
assert not model.system_prompt
|
Loading…
Reference in new issue