parent 2bc46b8193
commit 1c4f0d8ad5
@@ -1,74 +0,0 @@
Removed file: OpenAIAssistant, a dataclass wrapper around the OpenAI Assistants API.

from typing import Dict, List, Optional
from dataclasses import dataclass

from openai import OpenAI


@dataclass
class OpenAIAssistant:
    name: str = "OpenAI Assistant"
    instructions: str = None
    tools: List[Dict] = None
    model: str = None
    openai_api_key: str = None
    temperature: float = 0.5
    max_tokens: int = 100
    stop: List[str] = None
    echo: bool = False
    stream: bool = False
    log: bool = False
    presence: bool = False
    dashboard: bool = False
    debug: bool = False
    max_loops: int = 5
    stopping_condition: Optional[str] = None
    loop_interval: int = 1
    retry_attempts: int = 3
    retry_interval: int = 1
    interactive: bool = False
    dynamic_temperature: bool = False
    state: Dict = None
    response_filters: List = None
    response_filter: Dict = None
    response_filter_name: str = None
    response_filter_value: str = None
    response_filter_type: str = None
    response_filter_action: str = None
    response_filter_action_value: str = None
    response_filter_action_type: str = None
    response_filter_action_name: str = None
    role: str = "user"

    # Shared SDK client; reads OPENAI_API_KEY from the environment at import time.
    client = OpenAI()

    def create_assistant(self, task: str):
        assistant = self.client.beta.assistants.create(
            name=self.name,
            instructions=self.instructions,
            tools=self.tools,
            model=self.model,
        )
        return assistant

    def create_thread(self):
        thread = self.client.beta.threads.create()
        return thread

    def add_message_to_thread(self, thread_id: str, message: str):
        message = self.client.beta.threads.messages.create(
            thread_id=thread_id, role=self.role, content=message
        )
        return message

    def run(self, task: str):
        run = self.client.beta.threads.runs.create(
            thread_id=self.create_thread().id,
            assistant_id=self.create_assistant(task).id,
            instructions=self.instructions,
        )

        # Retrieve the run immediately; this does not poll for completion.
        out = self.client.beta.threads.runs.retrieve(
            thread_id=run.thread_id, run_id=run.id
        )

        return out
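For context, a minimal usage sketch of the removed wrapper follows. It is not part of the commit: the instance name, instructions, and prompt are made up, and it assumes an OPENAI_API_KEY is available in the environment when the client is constructed.

# Hypothetical usage of the removed OpenAIAssistant wrapper (illustration only).
assistant_agent = OpenAIAssistant(
    name="Docs Assistant",
    instructions="Answer questions about the swarms codebase.",
    model="gpt-4",
)

# run() creates a thread and an assistant, starts a run, and returns the Run
# object retrieved immediately afterwards; it does not wait for completion.
run_result = assistant_agent.run("Summarize what this class does.")
print(run_result.status)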
@@ -1,90 +1,215 @@
Rewritten: the HuggingFace LLM test module. The old version mocked torch, the tokenizer, and the model directly on HuggingFaceLLM; the new version targets the renamed HuggingfaceLLM class and adds per-attribute setter tests.

Old version:

import pytest
import torch
from unittest.mock import Mock
from swarms.models.huggingface import HuggingFaceLLM


@pytest.fixture
def mock_torch():
    return Mock()


@pytest.fixture
def mock_autotokenizer():
    return Mock()


@pytest.fixture
def mock_automodelforcausallm():
    return Mock()


@pytest.fixture
def mock_bitsandbytesconfig():
    return Mock()


@pytest.fixture
def hugging_face_llm(
    mock_torch,
    mock_autotokenizer,
    mock_automodelforcausallm,
    mock_bitsandbytesconfig,
):
    HuggingFaceLLM.torch = mock_torch
    HuggingFaceLLM.AutoTokenizer = mock_autotokenizer
    HuggingFaceLLM.AutoModelForCausalLM = mock_automodelforcausallm
    HuggingFaceLLM.BitsAndBytesConfig = mock_bitsandbytesconfig

    return HuggingFaceLLM(model_id="test")


def test_init(
    hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm
):
    assert hugging_face_llm.model_id == "test"
    mock_autotokenizer.from_pretrained.assert_called_once_with("test")
    mock_automodelforcausallm.from_pretrained.assert_called_once_with(
        "test", quantization_config=None
    )


def test_init_with_quantize(
    hugging_face_llm,
    mock_autotokenizer,
    mock_automodelforcausallm,
    mock_bitsandbytesconfig,
):
    quantization_config = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    mock_bitsandbytesconfig.return_value = quantization_config

    HuggingFaceLLM(model_id="test", quantize=True)

    mock_bitsandbytesconfig.assert_called_once_with(
        **quantization_config
    )
    mock_autotokenizer.from_pretrained.assert_called_once_with("test")
    mock_automodelforcausallm.from_pretrained.assert_called_once_with(
        "test", quantization_config=quantization_config
    )


def test_generate_text(hugging_face_llm):
    prompt_text = "test prompt"
    expected_output = "test output"
    hugging_face_llm.tokenizer.encode.return_value = torch.tensor(
        [0]
    )  # Mock tensor
    hugging_face_llm.model.generate.return_value = torch.tensor(
        [0]
    )  # Mock tensor
    hugging_face_llm.tokenizer.decode.return_value = expected_output

    output = hugging_face_llm.generate_text(prompt_text)

    assert output == expected_output

New version:

import torch
import logging
from unittest.mock import patch
import pytest

from swarms.models.huggingface import HuggingfaceLLM


# Mock some functions and objects for testing
@pytest.fixture
def mock_huggingface_llm(monkeypatch):
    # Mock the model and tokenizer creation
    def mock_init(
        self,
        model_id,
        device="cpu",
        max_length=500,
        quantize=False,
        quantization_config=None,
        verbose=False,
        distributed=False,
        decoding=False,
        max_workers=5,
        repitition_penalty=1.3,
        no_repeat_ngram_size=5,
        temperature=0.7,
        top_k=40,
        top_p=0.8,
    ):
        pass

    # Mock the model loading
    def mock_load_model(self):
        pass

    # Mock the model generation
    def mock_run(self, task):
        pass

    monkeypatch.setattr(HuggingfaceLLM, "__init__", mock_init)
    monkeypatch.setattr(HuggingfaceLLM, "load_model", mock_load_model)
    monkeypatch.setattr(HuggingfaceLLM, "run", mock_run)


# Basic tests for initialization and attribute settings
def test_init_huggingface_llm():
    llm = HuggingfaceLLM(
        model_id="test_model",
        device="cuda",
        max_length=1000,
        quantize=True,
        quantization_config={"config_key": "config_value"},
        verbose=True,
        distributed=True,
        decoding=True,
        max_workers=3,
        repitition_penalty=1.5,
        no_repeat_ngram_size=4,
        temperature=0.8,
        top_k=50,
        top_p=0.7,
    )

    assert llm.model_id == "test_model"
    assert llm.device == "cuda"
    assert llm.max_length == 1000
    assert llm.quantize == True
    assert llm.quantization_config == {"config_key": "config_value"}
    assert llm.verbose == True
    assert llm.distributed == True
    assert llm.decoding == True
    assert llm.max_workers == 3
    assert llm.repitition_penalty == 1.5
    assert llm.no_repeat_ngram_size == 4
    assert llm.temperature == 0.8
    assert llm.top_k == 50
    assert llm.top_p == 0.7


# Test loading the model
def test_load_model(mock_huggingface_llm):
    llm = HuggingfaceLLM(model_id="test_model")
    llm.load_model()

    # Ensure that the load_model function is called
    assert True


# Test running the model
def test_run(mock_huggingface_llm):
    llm = HuggingfaceLLM(model_id="test_model")
    result = llm.run("Test prompt")

    # Ensure that the run function is called
    assert True


# Test for setting max_length
def test_llm_set_max_length(llm_instance):
    new_max_length = 1000
    llm_instance.set_max_length(new_max_length)
    assert llm_instance.max_length == new_max_length


# Test for setting verbose
def test_llm_set_verbose(llm_instance):
    llm_instance.set_verbose(True)
    assert llm_instance.verbose is True


# Test for setting distributed
def test_llm_set_distributed(llm_instance):
    llm_instance.set_distributed(True)
    assert llm_instance.distributed is True


# Test for setting decoding
def test_llm_set_decoding(llm_instance):
    llm_instance.set_decoding(True)
    assert llm_instance.decoding is True


# Test for setting max_workers
def test_llm_set_max_workers(llm_instance):
    new_max_workers = 10
    llm_instance.set_max_workers(new_max_workers)
    assert llm_instance.max_workers == new_max_workers


# Test for setting repitition_penalty
def test_llm_set_repitition_penalty(llm_instance):
    new_repitition_penalty = 1.5
    llm_instance.set_repitition_penalty(new_repitition_penalty)
    assert llm_instance.repitition_penalty == new_repitition_penalty


# Test for setting no_repeat_ngram_size
def test_llm_set_no_repeat_ngram_size(llm_instance):
    new_no_repeat_ngram_size = 6
    llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size)
    assert llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size


# Test for setting temperature
def test_llm_set_temperature(llm_instance):
    new_temperature = 0.8
    llm_instance.set_temperature(new_temperature)
    assert llm_instance.temperature == new_temperature


# Test for setting top_k
def test_llm_set_top_k(llm_instance):
    new_top_k = 50
    llm_instance.set_top_k(new_top_k)
    assert llm_instance.top_k == new_top_k


# Test for setting top_p
def test_llm_set_top_p(llm_instance):
    new_top_p = 0.9
    llm_instance.set_top_p(new_top_p)
    assert llm_instance.top_p == new_top_p


# Test for setting quantize
def test_llm_set_quantize(llm_instance):
    llm_instance.set_quantize(True)
    assert llm_instance.quantize is True


# Test for setting quantization_config
def test_llm_set_quantization_config(llm_instance):
    new_quantization_config = {
        "load_in_4bit": False,
        "bnb_4bit_use_double_quant": False,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    llm_instance.set_quantization_config(new_quantization_config)
    assert llm_instance.quantization_config == new_quantization_config


# Test for setting model_id
def test_llm_set_model_id(llm_instance):
    new_model_id = "EleutherAI/gpt-neo-2.7B"
    llm_instance.set_model_id(new_model_id)
    assert llm_instance.model_id == new_model_id


# Test for setting model
@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained")
def test_llm_set_model(mock_model, llm_instance):
    mock_model.return_value = "mocked model"
    llm_instance.set_model(mock_model)
    assert llm_instance.model == "mocked model"


# Test for setting tokenizer
@patch("swarms.models.huggingface.AutoTokenizer.from_pretrained")
def test_llm_set_tokenizer(mock_tokenizer, llm_instance):
    mock_tokenizer.return_value = "mocked tokenizer"
    llm_instance.set_tokenizer(mock_tokenizer)
    assert llm_instance.tokenizer == "mocked tokenizer"


# Test for setting logger
def test_llm_set_logger(llm_instance):
    new_logger = logging.getLogger("test_logger")
    llm_instance.set_logger(new_logger)
    assert llm_instance.logger == new_logger


# Test for saving model
@patch("torch.save")
def test_llm_save_model(mock_save, llm_instance):
    llm_instance.save_model("path/to/save")
    mock_save.assert_called_once()


# Test for print_dashboard
@patch("builtins.print")
def test_llm_print_dashboard(mock_print, llm_instance):
    llm_instance.print_dashboard("test task")
    mock_print.assert_called()


# Test for __call__ method
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
def test_llm_call(mock_run, llm_instance):
    mock_run.return_value = "mocked output"
    result = llm_instance("test task")
    assert result == "mocked output"
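One gap worth noting: the setter tests above take an llm_instance fixture that never appears in this hunk. A minimal sketch of such a fixture follows; it is an assumption rather than part of the commit, and it patches the same loading paths the tests themselves patch so that no model weights are downloaded.

# Hypothetical llm_instance fixture (not in the diff); assumes HuggingfaceLLM
# only needs a model_id once AutoTokenizer/AutoModelForCausalLM are patched out.
import pytest
from unittest.mock import patch

from swarms.models.huggingface import HuggingfaceLLM


@pytest.fixture
def llm_instance():
    with patch(
        "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
    ), patch("swarms.models.huggingface.AutoTokenizer.from_pretrained"):
        yield HuggingfaceLLM(model_id="gpt2")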
@@ -1,7 +1,7 @@
Changed: the BaseStructure test module now imports BaseStructure from swarms.structs.base instead of swarms.swarms.base.

Old version:

import pytest
import os
from datetime import datetime
from swarms.swarms.base import BaseStructure


class TestBaseStructure:

New version:

import pytest
import os
from datetime import datetime
from swarms.structs.base import BaseStructure


class TestBaseStructure:
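The hunk stops at the class declaration, so none of the test bodies are visible here. A quick smoke check against the new import path would look like the sketch below; it assumes nothing beyond the module location introduced by this commit.

# Hypothetical smoke test for the relocated module path (illustration only).
def test_base_structure_import():
    from swarms.structs.base import BaseStructure

    assert BaseStructure is not None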