From 4bef09a252618f3802166d505de4f846094349d1 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 12 Dec 2023 16:49:46 -0800 Subject: [PATCH] [FEAT][TogertherModel] --- pyproject.toml | 2 +- swarms/models/together.py | 140 ++++++++++++++++++++++++++++++++++ swarms/prompts/react.py | 5 +- tests/models/test_togther.py | 144 +++++++++++++++++++++++++++++++++++ 4 files changed, 287 insertions(+), 4 deletions(-) create mode 100644 swarms/models/together.py create mode 100644 tests/models/test_togther.py diff --git a/pyproject.toml b/pyproject.toml index 8c53e85d..693ede3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.7.9" +version = "2.8.0" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/models/together.py b/swarms/models/together.py new file mode 100644 index 00000000..88949a5c --- /dev/null +++ b/swarms/models/together.py @@ -0,0 +1,140 @@ +import logging +import os +from typing import Optional + +import requests +from dotenv import load_dotenv + +from swarms.models.base_llm import AbstractLLM + +# Load environment variables +load_dotenv() + + +def together_api_key_env(): + """Get the API key from the environment.""" + return os.getenv("TOGETHER_API_KEY") + + +class TogetherModel(AbstractLLM): + """ + GPT-4 Vision API + + This class is a wrapper for the OpenAI API. It is used to run the GPT-4 Vision model. + + Parameters + ---------- + together_api_key : str + The OpenAI API key. Defaults to the together_api_key environment variable. + max_tokens : int + The maximum number of tokens to generate. Defaults to 300. + + + Methods + ------- + encode_image(img: str) + Encode image to base64. + run(task: str, img: str) + Run the model. + __call__(task: str, img: str) + Run the model. + + Examples: + --------- + >>> from swarms.models import GPT4VisionAPI + >>> llm = GPT4VisionAPI() + >>> task = "What is the color of the object?" + >>> img = "https://i.imgur.com/2M2ZGwC.jpeg" + >>> llm.run(task, img) + + + """ + + def __init__( + self, + together_api_key: str = together_api_key_env, + model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1", + logging_enabled: bool = False, + max_workers: int = 10, + max_tokens: str = 300, + api_endpoint: str = "https://api.together.xyz", + beautify: bool = False, + streaming_enabled: Optional[bool] = False, + meta_prompt: Optional[bool] = False, + system_prompt: Optional[str] = None, + *args, + **kwargs, + ): + super(TogetherModel).__init__(*args, **kwargs) + self.together_api_key = together_api_key + self.logging_enabled = logging_enabled + self.model_name = model_name + self.max_workers = max_workers + self.max_tokens = max_tokens + self.api_endpoint = api_endpoint + self.beautify = beautify + self.streaming_enabled = streaming_enabled + self.meta_prompt = meta_prompt + self.system_prompt = system_prompt + + if self.logging_enabled: + logging.basicConfig(level=logging.DEBUG) + else: + # Disable debug logs for requests and urllib3 + logging.getLogger("requests").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) + + if self.meta_prompt: + self.system_prompt = self.meta_prompt_init() + + # Function to handle vision tasks + def run(self, task: str = None, *args, **kwargs): + """Run the model.""" + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.together_api_key}", + } + payload = { + "model": self.model_name, + "messages": [ + { + "role": "system", + "content": [self.system_prompt], + }, + { + "role": "user", + "content": task, + }, + ], + "max_tokens": self.max_tokens, + **kwargs, + } + response = requests.post( + self.api_endpoint, + headers=headers, + json=payload, + *args, + **kwargs, + ) + + out = response.json() + if "choices" in out and out["choices"]: + content = ( + out["choices"][0] + .get("message", {}) + .get("content", None) + ) + if self.streaming_enabled: + content = self.stream_response(content) + return content + else: + print("No valid response in 'choices'") + return None + + except Exception as error: + print( + f"Error with the request: {error}, make sure you" + " double check input types and positions" + ) + return None diff --git a/swarms/prompts/react.py b/swarms/prompts/react.py index d4a8aeda..33dc8575 100644 --- a/swarms/prompts/react.py +++ b/swarms/prompts/react.py @@ -1,6 +1,5 @@ - def react_prompt(task: str = None): - REACT = f""" + PROMPT = f""" Task Description: Accomplish the following {task} using the reasoning guidelines below. @@ -56,4 +55,4 @@ def react_prompt(task: str = None): Remember, your goal is to provide a transparent and logical process that leads from observation to effective action. Your responses should demonstrate clear thinking, an understanding of the problem, and a rational approach to solving it. The use of tokens helps to structure your response and clarify the different stages of your reasoning and action. """ - return REACT \ No newline at end of file + return PROMPT diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py new file mode 100644 index 00000000..75313a45 --- /dev/null +++ b/tests/models/test_togther.py @@ -0,0 +1,144 @@ +import os +import requests +import pytest +from unittest.mock import patch, Mock +from swarms.models.together import TogetherModel +import logging + + +@pytest.fixture +def mock_api_key(monkeypatch): + monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") + + +def test_init_defaults(): + model = TogetherModel() + assert model.together_api_key == "mocked-api-key" + assert model.logging_enabled is False + assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1" + assert model.max_workers == 10 + assert model.max_tokens == 300 + assert model.api_endpoint == "https://api.together.xyz" + assert model.beautify is False + assert model.streaming_enabled is False + assert model.meta_prompt is False + assert model.system_prompt is None + + +def test_init_custom_params(mock_api_key): + model = TogetherModel( + together_api_key="custom-api-key", + logging_enabled=True, + model_name="custom-model", + max_workers=5, + max_tokens=500, + api_endpoint="https://custom-api.together.xyz", + beautify=True, + streaming_enabled=True, + meta_prompt="meta-prompt", + system_prompt="system-prompt", + ) + assert model.together_api_key == "custom-api-key" + assert model.logging_enabled is True + assert model.model_name == "custom-model" + assert model.max_workers == 5 + assert model.max_tokens == 500 + assert model.api_endpoint == "https://custom-api.together.xyz" + assert model.beautify is True + assert model.streaming_enabled is True + assert model.meta_prompt == "meta-prompt" + assert model.system_prompt == "system-prompt" + + +@patch("swarms.models.together_model.requests.post") +def test_run_success(mock_post, mock_api_key): + mock_response = Mock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "Generated response"}}] + } + mock_post.return_value = mock_response + + model = TogetherModel() + task = "What is the color of the object?" + response = model.run(task) + + assert response == "Generated response" + + +@patch("swarms.models.together_model.requests.post") +def test_run_failure(mock_post, mock_api_key): + mock_post.side_effect = requests.exceptions.RequestException( + "Request failed" + ) + + model = TogetherModel() + task = "What is the color of the object?" + response = model.run(task) + + assert response is None + + +def test_run_with_logging_enabled(caplog, mock_api_key): + model = TogetherModel(logging_enabled=True) + task = "What is the color of the object?" + + with caplog.at_level(logging.DEBUG): + model.run(task) + + assert "Sending request to" in caplog.text + + +@pytest.mark.parametrize( + "invalid_input", [None, 123, ["list", "of", "items"]] +) +def test_invalid_task_input(invalid_input, mock_api_key): + model = TogetherModel() + response = model.run(invalid_input) + + assert response is None + + +@patch("swarms.models.together_model.requests.post") +def test_run_streaming_enabled(mock_post, mock_api_key): + mock_response = Mock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "Generated response"}}] + } + mock_post.return_value = mock_response + + model = TogetherModel(streaming_enabled=True) + task = "What is the color of the object?" + response = model.run(task) + + assert response == "Generated response" + + +@patch("swarms.models.together_model.requests.post") +def test_run_empty_choices(mock_post, mock_api_key): + mock_response = Mock() + mock_response.json.return_value = {"choices": []} + mock_post.return_value = mock_response + + model = TogetherModel() + task = "What is the color of the object?" + response = model.run(task) + + assert response is None + + +@patch("swarms.models.together_model.requests.post") +def test_run_with_exception(mock_post, mock_api_key): + mock_post.side_effect = Exception("Test exception") + + model = TogetherModel() + task = "What is the color of the object?" + response = model.run(task) + + assert response is None + + +def test_init_logging_disabled(monkeypatch): + monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") + model = TogetherModel() + assert model.logging_enabled is False + assert not model.system_prompt