|
|
@ -1,7 +1,7 @@
|
|
|
|
import requests
|
|
|
|
import requests
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
from unittest.mock import patch, Mock
|
|
|
|
from unittest.mock import patch, Mock
|
|
|
|
from swarms.models.together import TogetherModel
|
|
|
|
from swarms.models.together import TogetherLLM
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -11,7 +11,7 @@ def mock_api_key(monkeypatch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_init_defaults():
|
|
|
|
def test_init_defaults():
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
assert model.together_api_key == "mocked-api-key"
|
|
|
|
assert model.together_api_key == "mocked-api-key"
|
|
|
|
assert model.logging_enabled is False
|
|
|
|
assert model.logging_enabled is False
|
|
|
|
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
|
|
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
|
@ -25,7 +25,7 @@ def test_init_defaults():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_init_custom_params(mock_api_key):
|
|
|
|
def test_init_custom_params(mock_api_key):
|
|
|
|
model = TogetherModel(
|
|
|
|
model = TogetherLLM(
|
|
|
|
together_api_key="custom-api-key",
|
|
|
|
together_api_key="custom-api-key",
|
|
|
|
logging_enabled=True,
|
|
|
|
logging_enabled=True,
|
|
|
|
model_name="custom-model",
|
|
|
|
model_name="custom-model",
|
|
|
@ -57,7 +57,7 @@ def test_run_success(mock_post, mock_api_key):
|
|
|
|
}
|
|
|
|
}
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
|
|
|
|
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
response = model.run(task)
|
|
|
|
response = model.run(task)
|
|
|
|
|
|
|
|
|
|
|
@ -70,7 +70,7 @@ def test_run_failure(mock_post, mock_api_key):
|
|
|
|
"Request failed"
|
|
|
|
"Request failed"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
response = model.run(task)
|
|
|
|
response = model.run(task)
|
|
|
|
|
|
|
|
|
|
|
@ -78,7 +78,7 @@ def test_run_failure(mock_post, mock_api_key):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_run_with_logging_enabled(caplog, mock_api_key):
|
|
|
|
def test_run_with_logging_enabled(caplog, mock_api_key):
|
|
|
|
model = TogetherModel(logging_enabled=True)
|
|
|
|
model = TogetherLLM(logging_enabled=True)
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
|
|
|
|
|
|
|
|
with caplog.at_level(logging.DEBUG):
|
|
|
|
with caplog.at_level(logging.DEBUG):
|
|
|
@ -91,7 +91,7 @@ def test_run_with_logging_enabled(caplog, mock_api_key):
|
|
|
|
"invalid_input", [None, 123, ["list", "of", "items"]]
|
|
|
|
"invalid_input", [None, 123, ["list", "of", "items"]]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
def test_invalid_task_input(invalid_input, mock_api_key):
|
|
|
|
def test_invalid_task_input(invalid_input, mock_api_key):
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
response = model.run(invalid_input)
|
|
|
|
response = model.run(invalid_input)
|
|
|
|
|
|
|
|
|
|
|
|
assert response is None
|
|
|
|
assert response is None
|
|
|
@ -105,7 +105,7 @@ def test_run_streaming_enabled(mock_post, mock_api_key):
|
|
|
|
}
|
|
|
|
}
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
|
|
|
|
|
|
|
|
model = TogetherModel(streaming_enabled=True)
|
|
|
|
model = TogetherLLM(streaming_enabled=True)
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
response = model.run(task)
|
|
|
|
response = model.run(task)
|
|
|
|
|
|
|
|
|
|
|
@ -118,7 +118,7 @@ def test_run_empty_choices(mock_post, mock_api_key):
|
|
|
|
mock_response.json.return_value = {"choices": []}
|
|
|
|
mock_response.json.return_value = {"choices": []}
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
mock_post.return_value = mock_response
|
|
|
|
|
|
|
|
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
response = model.run(task)
|
|
|
|
response = model.run(task)
|
|
|
|
|
|
|
|
|
|
|
@ -129,7 +129,7 @@ def test_run_empty_choices(mock_post, mock_api_key):
|
|
|
|
def test_run_with_exception(mock_post, mock_api_key):
|
|
|
|
def test_run_with_exception(mock_post, mock_api_key):
|
|
|
|
mock_post.side_effect = Exception("Test exception")
|
|
|
|
mock_post.side_effect = Exception("Test exception")
|
|
|
|
|
|
|
|
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
task = "What is the color of the object?"
|
|
|
|
response = model.run(task)
|
|
|
|
response = model.run(task)
|
|
|
|
|
|
|
|
|
|
|
@ -138,6 +138,6 @@ def test_run_with_exception(mock_post, mock_api_key):
|
|
|
|
|
|
|
|
|
|
|
|
def test_init_logging_disabled(monkeypatch):
|
|
|
|
def test_init_logging_disabled(monkeypatch):
|
|
|
|
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
|
|
|
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
|
|
|
model = TogetherModel()
|
|
|
|
model = TogetherLLM()
|
|
|
|
assert model.logging_enabled is False
|
|
|
|
assert model.logging_enabled is False
|
|
|
|
assert not model.system_prompt
|
|
|
|
assert not model.system_prompt
|
|
|
|