parent
b1d3aa54a8
commit
4bef09a252
@ -0,0 +1,140 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms.models.base_llm import AbstractLLM
|
||||
|
||||
# Load variables from a local .env file into os.environ so that
# together_api_key_env() can pick up TOGETHER_API_KEY.
load_dotenv()
|
||||
|
||||
|
||||
def together_api_key_env():
    """Return the Together API key from the ``TOGETHER_API_KEY`` env var."""
    return os.environ.get("TOGETHER_API_KEY")
|
||||
|
||||
|
||||
class TogetherModel(AbstractLLM):
    """
    Wrapper for the Together AI chat-completions REST API.

    Sends a ``task`` string to a hosted Together model and returns the
    generated text.

    Parameters
    ----------
    together_api_key : str, optional
        Together API key. Defaults to the ``TOGETHER_API_KEY``
        environment variable.
    model_name : str
        Name of the hosted model to query.
    logging_enabled : bool
        When True, process-wide DEBUG logging is enabled.
    max_workers : int
        Stored for callers that fan out requests; not used here.
    max_tokens : int
        Maximum number of tokens to generate. Defaults to 300.
    api_endpoint : str
        URL the request is POSTed to.
    beautify : bool
        Stored flag; not used by this class directly.
    streaming_enabled : bool, optional
        When True, response content is passed through
        ``self.stream_response`` before being returned.
    meta_prompt : bool, optional
        When truthy, ``self.meta_prompt_init()`` is used to build the
        system prompt.
    system_prompt : str, optional
        System prompt sent with every request.

    Methods
    -------
    run(task: str)
        Send ``task`` to the API and return the generated text.

    Examples
    --------
    >>> llm = TogetherModel()
    >>> llm.run("What is the capital of France?")
    """

    def __init__(
        self,
        together_api_key: Optional[str] = None,
        model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        logging_enabled: bool = False,
        max_workers: int = 10,
        max_tokens: int = 300,
        api_endpoint: str = "https://api.together.xyz",
        beautify: bool = False,
        streaming_enabled: Optional[bool] = False,
        meta_prompt: Optional[bool] = False,
        system_prompt: Optional[str] = None,
        *args,
        **kwargs,
    ):
        # Bug fix: the original called `super(TogetherModel).__init__(...)`,
        # which builds an *unbound* super object and never runs the base
        # class initializer.
        super().__init__(*args, **kwargs)
        # Bug fix: the original default was the function object
        # `together_api_key_env` (never called), so the stored "key" was a
        # function, not a string. Resolve the environment variable here.
        if together_api_key is None:
            together_api_key = together_api_key_env()
        self.together_api_key = together_api_key
        self.logging_enabled = logging_enabled
        self.model_name = model_name
        self.max_workers = max_workers
        self.max_tokens = max_tokens
        self.api_endpoint = api_endpoint
        self.beautify = beautify
        self.streaming_enabled = streaming_enabled
        self.meta_prompt = meta_prompt
        self.system_prompt = system_prompt

        if self.logging_enabled:
            logging.basicConfig(level=logging.DEBUG)
        else:
            # Disable debug logs for requests and urllib3
            logging.getLogger("requests").setLevel(logging.WARNING)
            logging.getLogger("urllib3").setLevel(logging.WARNING)

        if self.meta_prompt:
            # NOTE(review): meta_prompt_init is not defined in this file —
            # presumably provided by AbstractLLM; confirm it exists.
            self.system_prompt = self.meta_prompt_init()

    def run(self, task: str = None, *args, **kwargs):
        """Send ``task`` to the Together API and return the generated text.

        Parameters
        ----------
        task : str
            The user prompt to send.
        **kwargs
            Extra fields merged into the JSON payload (e.g. temperature).

        Returns
        -------
        str or None
            The generated content, or None when the request fails or the
            response carries no usable choices.
        """
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.together_api_key}",
            }
            messages = []
            if self.system_prompt is not None:
                # Bug fix: the original always sent `[self.system_prompt]`
                # (a list, possibly [None]); send the plain string and
                # omit the system message entirely when unset.
                messages.append(
                    {"role": "system", "content": self.system_prompt}
                )
            messages.append({"role": "user", "content": task})
            payload = {
                "model": self.model_name,
                "messages": messages,
                "max_tokens": self.max_tokens,
                **kwargs,
            }
            # NOTE(review): the default endpoint is the API root; Together's
            # chat-completions route is typically
            # "https://api.together.xyz/v1/chat/completions" — confirm.
            # Bug fix: **kwargs are payload fields and must NOT also be
            # forwarded to requests.post (unknown keyword arguments would
            # raise TypeError before the request is sent).
            response = requests.post(
                self.api_endpoint,
                headers=headers,
                json=payload,
            )

            out = response.json()
            if "choices" in out and out["choices"]:
                content = (
                    out["choices"][0]
                    .get("message", {})
                    .get("content", None)
                )
                if self.streaming_enabled:
                    # stream_response is expected from the base class.
                    content = self.stream_response(content)
                return content
            print("No valid response in 'choices'")
            return None
        except Exception as error:
            print(
                f"Error with the request: {error}, make sure you"
                " double check input types and positions"
            )
            return None
|
@ -0,0 +1,144 @@
|
||||
import os
|
||||
import requests
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from swarms.models.together import TogetherModel
|
||||
import logging
|
||||
|
||||
|
||||
@pytest.fixture
def mock_api_key(monkeypatch):
    """Set a dummy TOGETHER_API_KEY for the duration of a single test."""
    monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
||||
|
||||
|
||||
def test_init_defaults(mock_api_key):
    """All constructor defaults, with the API key read from the environment.

    Bug fix: the original never requested the ``mock_api_key`` fixture, so
    ``TOGETHER_API_KEY`` was unset and the first assertion could not pass.
    """
    model = TogetherModel()
    assert model.together_api_key == "mocked-api-key"
    assert model.logging_enabled is False
    assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
    assert model.max_workers == 10
    assert model.max_tokens == 300
    assert model.api_endpoint == "https://api.together.xyz"
    assert model.beautify is False
    assert model.streaming_enabled is False
    assert model.meta_prompt is False
    assert model.system_prompt is None
|
||||
|
||||
|
||||
def test_init_custom_params(mock_api_key):
    """Every constructor argument should be stored verbatim.

    Bug fix: the original passed ``meta_prompt="meta-prompt"``; any truthy
    value makes ``__init__`` call ``self.meta_prompt_init()`` (not defined
    on TogetherModel) and overwrite ``system_prompt``, so the final
    assertion could never hold. Keep meta_prompt falsy here.
    """
    model = TogetherModel(
        together_api_key="custom-api-key",
        logging_enabled=True,
        model_name="custom-model",
        max_workers=5,
        max_tokens=500,
        api_endpoint="https://custom-api.together.xyz",
        beautify=True,
        streaming_enabled=True,
        meta_prompt=False,
        system_prompt="system-prompt",
    )
    assert model.together_api_key == "custom-api-key"
    assert model.logging_enabled is True
    assert model.model_name == "custom-model"
    assert model.max_workers == 5
    assert model.max_tokens == 500
    assert model.api_endpoint == "https://custom-api.together.xyz"
    assert model.beautify is True
    assert model.streaming_enabled is True
    assert model.meta_prompt is False
    assert model.system_prompt == "system-prompt"
|
||||
|
||||
|
||||
# Bug fix: the patch target must match the module actually imported above
# ("swarms.models.together", not "swarms.models.together_model"), otherwise
# patch() raises ModuleNotFoundError before the test body runs.
@patch("swarms.models.together.requests.post")
def test_run_success(mock_post, mock_api_key):
    """run() should unwrap choices[0].message.content on success."""
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response == "Generated response"
|
||||
|
||||
|
||||
# Bug fix: patch target corrected to the imported module path
# "swarms.models.together" (was "swarms.models.together_model").
@patch("swarms.models.together.requests.post")
def test_run_failure(mock_post, mock_api_key):
    """run() swallows request exceptions and returns None."""
    mock_post.side_effect = requests.exceptions.RequestException(
        "Request failed"
    )

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None
|
||||
|
||||
|
||||
# Bug fix: the original made a live HTTP request (no mock) and asserted a
# log message ("Sending request to") that the model never emits, so it
# could only fail. Mock the POST and assert observable behavior instead.
@patch("swarms.models.together.requests.post")
def test_run_with_logging_enabled(mock_post, caplog, mock_api_key):
    """logging_enabled=True must not break a normal run()."""
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "ok"}}]
    }
    mock_post.return_value = mock_response

    model = TogetherModel(logging_enabled=True)
    with caplog.at_level(logging.DEBUG):
        response = model.run("What is the color of the object?")

    assert model.logging_enabled is True
    assert response == "ok"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_input", [None, 123, ["list", "of", "items"]]
)
# Bug fix: the original hit the live API for every parametrized case; mock
# the POST so the test is deterministic and offline.
@patch("swarms.models.together.requests.post")
def test_invalid_task_input(mock_post, invalid_input, mock_api_key):
    """Invalid task payloads should yield None, never raise."""
    mock_response = Mock()
    mock_response.json.return_value = {"choices": []}
    mock_post.return_value = mock_response

    model = TogetherModel()
    response = model.run(invalid_input)

    assert response is None
|
||||
|
||||
|
||||
# Bug fix: patch target corrected to the imported module path
# "swarms.models.together" (was "swarms.models.together_model").
@patch("swarms.models.together.requests.post")
def test_run_streaming_enabled(mock_post, mock_api_key):
    """Streaming mode should still return the generated content.

    NOTE(review): this assumes ``stream_response`` (expected from the
    AbstractLLM base class) returns the content unchanged — confirm.
    """
    mock_response = Mock()
    mock_response.json.return_value = {
        "choices": [{"message": {"content": "Generated response"}}]
    }
    mock_post.return_value = mock_response

    model = TogetherModel(streaming_enabled=True)
    task = "What is the color of the object?"
    response = model.run(task)

    assert response == "Generated response"
|
||||
|
||||
|
||||
# Bug fix: patch target corrected to the imported module path
# "swarms.models.together" (was "swarms.models.together_model").
@patch("swarms.models.together.requests.post")
def test_run_empty_choices(mock_post, mock_api_key):
    """An empty 'choices' list should produce None."""
    mock_response = Mock()
    mock_response.json.return_value = {"choices": []}
    mock_post.return_value = mock_response

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None
|
||||
|
||||
|
||||
# Bug fix: patch target corrected to the imported module path
# "swarms.models.together" (was "swarms.models.together_model").
@patch("swarms.models.together.requests.post")
def test_run_with_exception(mock_post, mock_api_key):
    """Arbitrary exceptions during the request are caught; run() returns None."""
    mock_post.side_effect = Exception("Test exception")

    model = TogetherModel()
    task = "What is the color of the object?"
    response = model.run(task)

    assert response is None
|
||||
|
||||
|
||||
def test_init_logging_disabled(monkeypatch):
    """A freshly constructed model has logging off and no system prompt."""
    monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
    instance = TogetherModel()
    assert instance.logging_enabled is False
    assert not instance.system_prompt
|
Loading…
Reference in new issue