From 79fac05df477327a4a6ba70417e8692172c148b8 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 4 Nov 2023 21:26:07 -0400 Subject: [PATCH] anthropic tests Former-commit-id: d4bd4fa4a47eaeba44a08164c0f464f3aaa24dcb --- swarms/models/anthropic.py | 24 +++++++- tests/models/anthropic.py | 116 +++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 tests/models/anthropic.py diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index 232ff647..e2066637 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -7,9 +7,31 @@ class Anthropic: Anthropic large language models. - Args: + model: The model to use. Defaults to "claude-2". + max_tokens_to_sample: The maximum number of tokens to sample. + temperature: The temperature to use for sampling. + top_k: The top_k to use for sampling. + top_p: The top_p to use for sampling. + streaming: Whether to stream the response or not. + default_request_timeout: The default request timeout to use. + + + Attributes: + model: The model to use. + max_tokens_to_sample: The maximum number of tokens to sample. + temperature: The temperature to use for sampling. + top_k: The top_k to use for sampling. + top_p: The top_p to use for sampling. + streaming: Whether to stream the response or not. + default_request_timeout: The default request timeout to use. + anthropic_api_url: The API URL to use. + anthropic_api_key: The API key to use. + Usage: + model_wrapper = Anthropic() + completion = model_wrapper("Hello, my name is") + print(completion) """ diff --git a/tests/models/anthropic.py b/tests/models/anthropic.py new file mode 100644 index 00000000..844415aa --- /dev/null +++ b/tests/models/anthropic.py @@ -0,0 +1,116 @@ +import os +import pytest +from unittest.mock import Mock, patch +from swarms.models.anthropic import Anthropic + +@pytest.fixture +def mock_anthropic_env(): + os.environ["ANTHROPIC_API_URL"] = "https://test.anthropic.com" + os.environ["ANTHROPIC_API_KEY"] = "test_api_key" + yield + del os.environ["ANTHROPIC_API_URL"] + del os.environ["ANTHROPIC_API_KEY"] + +@pytest.fixture +def mock_requests_post(): + with patch("requests.post") as mock_post: + yield mock_post + +@pytest.fixture +def anthropic_instance(): + return Anthropic(model="test-model") + +def test_anthropic_init_default_values(anthropic_instance): + assert anthropic_instance.model == "test-model" + assert anthropic_instance.max_tokens_to_sample == 256 + assert anthropic_instance.temperature is None + assert anthropic_instance.top_k is None + assert anthropic_instance.top_p is None + assert anthropic_instance.streaming is False + assert anthropic_instance.default_request_timeout == 600 + assert anthropic_instance.anthropic_api_url == "https://test.anthropic.com" + assert anthropic_instance.anthropic_api_key == "test_api_key" + +def test_anthropic_init_custom_values(): + anthropic_instance = Anthropic( + model="custom-model", + max_tokens_to_sample=128, + temperature=0.8, + top_k=5, + top_p=0.9, + streaming=True, + default_request_timeout=300, + ) + assert anthropic_instance.model == "custom-model" + assert anthropic_instance.max_tokens_to_sample == 128 + assert anthropic_instance.temperature == 0.8 + assert anthropic_instance.top_k == 5 + assert anthropic_instance.top_p == 0.9 + assert anthropic_instance.streaming is True + assert anthropic_instance.default_request_timeout == 300 + +def test_anthropic_default_params(anthropic_instance): + default_params = anthropic_instance._default_params() + assert default_params == { + "max_tokens_to_sample": 256, + "model": "test-model", + } + +def test_anthropic_run(mock_anthropic_env, mock_requests_post, anthropic_instance): + mock_response = Mock() + mock_response.json.return_value = {"completion": "Generated text"} + mock_requests_post.return_value = mock_response + + task = "Generate text" + stop = ["stop1", "stop2"] + + completion = anthropic_instance.run(task, stop) + + assert completion == "Generated text" + mock_requests_post.assert_called_once_with( + "https://test.anthropic.com/completions", + headers={"Authorization": "Bearer test_api_key"}, + json={ + "prompt": task, + "stop_sequences": stop, + "max_tokens_to_sample": 256, + "model": "test-model", + }, + timeout=600, + ) + +def test_anthropic_call(mock_anthropic_env, mock_requests_post, anthropic_instance): + mock_response = Mock() + mock_response.json.return_value = {"completion": "Generated text"} + mock_requests_post.return_value = mock_response + + task = "Generate text" + stop = ["stop1", "stop2"] + + completion = anthropic_instance(task, stop) + + assert completion == "Generated text" + mock_requests_post.assert_called_once_with( + "https://test.anthropic.com/completions", + headers={"Authorization": "Bearer test_api_key"}, + json={ + "prompt": task, + "stop_sequences": stop, + "max_tokens_to_sample": 256, + "model": "test-model", + }, + timeout=600, + ) + +def test_anthropic_exception_handling(mock_anthropic_env, mock_requests_post, anthropic_instance): + mock_response = Mock() + mock_response.json.return_value = {"error": "An error occurred"} + mock_requests_post.return_value = mock_response + + task = "Generate text" + stop = ["stop1", "stop2"] + + with pytest.raises(Exception) as excinfo: + anthropic_instance(task, stop) + + assert "An error occurred" in str(excinfo.value)