[FEAT][prep_torch_inference]

pull/334/head
Kye 1 year ago
parent a9b91d4653
commit f39d722f2a

@ -8,7 +8,7 @@ from swarms.memory.base_vectordb import VectorDatabase
try: try:
import weaviate import weaviate
except ImportError as error: except ImportError:
print("pip install weaviate-client") print("pip install weaviate-client")

@ -395,7 +395,7 @@ class AbstractLLM(ABC):
float: _description_ float: _description_
""" """
start_time = time.time() start_time = time.time()
tokens = self.track_resource_utilization( self.track_resource_utilization(
prompt prompt
) # assuming `generate` is a method that generates tokens ) # assuming `generate` is a method that generates tokens
first_token_time = time.time() first_token_time = time.time()
@ -411,7 +411,7 @@ class AbstractLLM(ABC):
float: _description_ float: _description_
""" """
start_time = time.time() start_time = time.time()
tokens = self.run(prompt) self.run(prompt)
end_time = time.time() end_time = time.time()
return end_time - start_time return end_time - start_time
@ -426,6 +426,6 @@ class AbstractLLM(ABC):
""" """
start_time = time.time() start_time = time.time()
for prompt in prompts: for prompt in prompts:
tokens = self.run(prompt) self.run(prompt)
end_time = time.time() end_time = time.time()
return len(prompts) / (end_time - start_time) return len(prompts) / (end_time - start_time)

@ -8,6 +8,9 @@ from swarms.utils.math_eval import math_eval
from swarms.utils.llm_metrics_decorator import metrics_decorator from swarms.utils.llm_metrics_decorator import metrics_decorator
from swarms.utils.device_checker_cuda import check_device from swarms.utils.device_checker_cuda import check_device
from swarms.utils.load_model_torch import load_model_torch from swarms.utils.load_model_torch import load_model_torch
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)
__all__ = [ __all__ = [
"display_markdown_message", "display_markdown_message",
@ -18,4 +21,5 @@ __all__ = [
"metrics_decorator", "metrics_decorator",
"check_device", "check_device",
"load_model_torch", "load_model_torch",
"prep_torch_inference",
] ]

@ -0,0 +1,30 @@
import torch
from swarms.utils.load_model_torch import load_model_torch
def prep_torch_inference(
model_path: str = None,
device: torch.device = None,
*args,
**kwargs,
):
"""
Prepare a Torch model for inference.
Args:
model_path (str): Path to the model file.
device (torch.device): Device to run the model on.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
torch.nn.Module: The prepared model.
"""
try:
model = load_model_torch(model_path, device)
model.eval()
return model
except Exception as e:
# Add error handling code here
print(f"Error occurred while preparing Torch model: {e}")
return None

@ -65,11 +65,11 @@ def test_init_huggingface_llm():
assert llm.model_id == "test_model" assert llm.model_id == "test_model"
assert llm.device == "cuda" assert llm.device == "cuda"
assert llm.max_length == 1000 assert llm.max_length == 1000
assert llm.quantize == True assert llm.quantize is True
assert llm.quantization_config == {"config_key": "config_value"} assert llm.quantization_config == {"config_key": "config_value"}
assert llm.verbose == True assert llm.verbose is True
assert llm.distributed == True assert llm.distributed is True
assert llm.decoding == True assert llm.decoding is True
assert llm.max_workers == 3 assert llm.max_workers == 3
assert llm.repitition_penalty == 1.5 assert llm.repitition_penalty == 1.5
assert llm.no_repeat_ngram_size == 4 assert llm.no_repeat_ngram_size == 4
@ -90,7 +90,7 @@ def test_load_model(mock_huggingface_llm):
# Test running the model # Test running the model
def test_run(mock_huggingface_llm): def test_run(mock_huggingface_llm):
llm = HuggingfaceLLM(model_id="test_model") llm = HuggingfaceLLM(model_id="test_model")
result = llm.run("Test prompt") llm.run("Test prompt")
# Ensure that the run function is called # Ensure that the run function is called
assert True assert True

@ -25,14 +25,14 @@ def test_load_model_torch_no_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module) mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model) mocker.patch("torch.load", return_value=mock_model)
mocker.patch("torch.cuda.is_available", return_value=False) mocker.patch("torch.cuda.is_available", return_value=False)
model = load_model_torch("model_path") load_model_torch("model_path")
mock_model.to.assert_called_once_with(torch.device("cpu")) mock_model.to.assert_called_once_with(torch.device("cpu"))
def test_load_model_torch_device_specified(mocker): def test_load_model_torch_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module) mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model) mocker.patch("torch.load", return_value=mock_model)
model = load_model_torch( load_model_torch(
"model_path", device=torch.device("cuda") "model_path", device=torch.device("cuda")
) )
mock_model.to.assert_called_once_with(torch.device("cuda")) mock_model.to.assert_called_once_with(torch.device("cuda"))

@ -0,0 +1,50 @@
import torch
from unittest.mock import MagicMock
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)
def test_prep_torch_inference_no_model_path():
result = prep_torch_inference()
assert result is None
def test_prep_torch_inference_model_not_found(mocker):
mocker.patch(
"swarms.utils.prep_torch_model_inference.load_model_torch",
side_effect=FileNotFoundError,
)
result = prep_torch_inference("non_existent_model_path")
assert result is None
def test_prep_torch_inference_runtime_error(mocker):
mocker.patch(
"swarms.utils.prep_torch_model_inference.load_model_torch",
side_effect=RuntimeError,
)
result = prep_torch_inference("model_path")
assert result is None
def test_prep_torch_inference_no_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch(
"swarms.utils.prep_torch_model_inference.load_model_torch",
return_value=mock_model,
)
prep_torch_inference("model_path")
mock_model.eval.assert_called_once()
def test_prep_torch_inference_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch(
"swarms.utils.prep_torch_model_inference.load_model_torch",
return_value=mock_model,
)
prep_torch_inference(
"model_path", device=torch.device("cuda")
)
mock_model.eval.assert_called_once()
Loading…
Cancel
Save