[FEAT][TogertherModel]

pull/294/head^2
Kye 1 year ago
parent b1d3aa54a8
commit 4bef09a252

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "2.7.9" version = "2.8.0"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

@ -0,0 +1,140 @@
import logging
import os
from typing import Optional
import requests
from dotenv import load_dotenv
from swarms.models.base_llm import AbstractLLM
# Load environment variables
load_dotenv()
def together_api_key_env():
"""Get the API key from the environment."""
return os.getenv("TOGETHER_API_KEY")
class TogetherModel(AbstractLLM):
"""
GPT-4 Vision API
This class is a wrapper for the OpenAI API. It is used to run the GPT-4 Vision model.
Parameters
----------
together_api_key : str
The OpenAI API key. Defaults to the together_api_key environment variable.
max_tokens : int
The maximum number of tokens to generate. Defaults to 300.
Methods
-------
encode_image(img: str)
Encode image to base64.
run(task: str, img: str)
Run the model.
__call__(task: str, img: str)
Run the model.
Examples:
---------
>>> from swarms.models import GPT4VisionAPI
>>> llm = GPT4VisionAPI()
>>> task = "What is the color of the object?"
>>> img = "https://i.imgur.com/2M2ZGwC.jpeg"
>>> llm.run(task, img)
"""
def __init__(
self,
together_api_key: str = together_api_key_env,
model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
logging_enabled: bool = False,
max_workers: int = 10,
max_tokens: str = 300,
api_endpoint: str = "https://api.together.xyz",
beautify: bool = False,
streaming_enabled: Optional[bool] = False,
meta_prompt: Optional[bool] = False,
system_prompt: Optional[str] = None,
*args,
**kwargs,
):
super(TogetherModel).__init__(*args, **kwargs)
self.together_api_key = together_api_key
self.logging_enabled = logging_enabled
self.model_name = model_name
self.max_workers = max_workers
self.max_tokens = max_tokens
self.api_endpoint = api_endpoint
self.beautify = beautify
self.streaming_enabled = streaming_enabled
self.meta_prompt = meta_prompt
self.system_prompt = system_prompt
if self.logging_enabled:
logging.basicConfig(level=logging.DEBUG)
else:
# Disable debug logs for requests and urllib3
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
if self.meta_prompt:
self.system_prompt = self.meta_prompt_init()
# Function to handle vision tasks
def run(self, task: str = None, *args, **kwargs):
"""Run the model."""
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.together_api_key}",
}
payload = {
"model": self.model_name,
"messages": [
{
"role": "system",
"content": [self.system_prompt],
},
{
"role": "user",
"content": task,
},
],
"max_tokens": self.max_tokens,
**kwargs,
}
response = requests.post(
self.api_endpoint,
headers=headers,
json=payload,
*args,
**kwargs,
)
out = response.json()
if "choices" in out and out["choices"]:
content = (
out["choices"][0]
.get("message", {})
.get("content", None)
)
if self.streaming_enabled:
content = self.stream_response(content)
return content
else:
print("No valid response in 'choices'")
return None
except Exception as error:
print(
f"Error with the request: {error}, make sure you"
" double check input types and positions"
)
return None

@ -1,6 +1,5 @@
def react_prompt(task: str = None): def react_prompt(task: str = None):
REACT = f""" PROMPT = f"""
Task Description: Task Description:
Accomplish the following {task} using the reasoning guidelines below. Accomplish the following {task} using the reasoning guidelines below.
@ -56,4 +55,4 @@ def react_prompt(task: str = None):
Remember, your goal is to provide a transparent and logical process that leads from observation to effective action. Your responses should demonstrate clear thinking, an understanding of the problem, and a rational approach to solving it. The use of tokens helps to structure your response and clarify the different stages of your reasoning and action. Remember, your goal is to provide a transparent and logical process that leads from observation to effective action. Your responses should demonstrate clear thinking, an understanding of the problem, and a rational approach to solving it. The use of tokens helps to structure your response and clarify the different stages of your reasoning and action.
""" """
return REACT return PROMPT

@ -0,0 +1,144 @@
import os
import requests
import pytest
from unittest.mock import patch, Mock
from swarms.models.together import TogetherModel
import logging
@pytest.fixture
def mock_api_key(monkeypatch):
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
def test_init_defaults():
model = TogetherModel()
assert model.together_api_key == "mocked-api-key"
assert model.logging_enabled is False
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
assert model.max_workers == 10
assert model.max_tokens == 300
assert model.api_endpoint == "https://api.together.xyz"
assert model.beautify is False
assert model.streaming_enabled is False
assert model.meta_prompt is False
assert model.system_prompt is None
def test_init_custom_params(mock_api_key):
model = TogetherModel(
together_api_key="custom-api-key",
logging_enabled=True,
model_name="custom-model",
max_workers=5,
max_tokens=500,
api_endpoint="https://custom-api.together.xyz",
beautify=True,
streaming_enabled=True,
meta_prompt="meta-prompt",
system_prompt="system-prompt",
)
assert model.together_api_key == "custom-api-key"
assert model.logging_enabled is True
assert model.model_name == "custom-model"
assert model.max_workers == 5
assert model.max_tokens == 500
assert model.api_endpoint == "https://custom-api.together.xyz"
assert model.beautify is True
assert model.streaming_enabled is True
assert model.meta_prompt == "meta-prompt"
assert model.system_prompt == "system-prompt"
@patch("swarms.models.together_model.requests.post")
def test_run_success(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response
model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)
assert response == "Generated response"
@patch("swarms.models.together_model.requests.post")
def test_run_failure(mock_post, mock_api_key):
mock_post.side_effect = requests.exceptions.RequestException(
"Request failed"
)
model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)
assert response is None
def test_run_with_logging_enabled(caplog, mock_api_key):
model = TogetherModel(logging_enabled=True)
task = "What is the color of the object?"
with caplog.at_level(logging.DEBUG):
model.run(task)
assert "Sending request to" in caplog.text
@pytest.mark.parametrize(
"invalid_input", [None, 123, ["list", "of", "items"]]
)
def test_invalid_task_input(invalid_input, mock_api_key):
model = TogetherModel()
response = model.run(invalid_input)
assert response is None
@patch("swarms.models.together_model.requests.post")
def test_run_streaming_enabled(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response
model = TogetherModel(streaming_enabled=True)
task = "What is the color of the object?"
response = model.run(task)
assert response == "Generated response"
@patch("swarms.models.together_model.requests.post")
def test_run_empty_choices(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {"choices": []}
mock_post.return_value = mock_response
model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)
assert response is None
@patch("swarms.models.together_model.requests.post")
def test_run_with_exception(mock_post, mock_api_key):
mock_post.side_effect = Exception("Test exception")
model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)
assert response is None
def test_init_logging_disabled(monkeypatch):
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
model = TogetherModel()
assert model.logging_enabled is False
assert not model.system_prompt
Loading…
Cancel
Save