[FEATS][Tokenizers] [TimmModel] [Odin] [UltralyticsModel]

pull/362/head
Kye 12 months ago
parent d70398806d
commit e29ec9c943

@ -20,52 +20,4 @@ workflow.add(task)
workflow.run() workflow.run()
``` ```
Returns: None
#### Source Code:
```python
class RecursiveWorkflow(BaseStructure):
def __init__(self, stop_token: str = "<DONE>"):
"""
Args:
stop_token (str, optional): The token that indicates when to stop the workflow. Default is "<DONE>".
The stop_token indicates the value at which the current workflow is finished.
"""
self.stop_token = stop_token
self.tasks = []
assert (
self.stop_token is not None
), "stop_token cannot be None"
def add(self, task: Task, tasks: List[Task] = None):
"""Adds a task to the workflow.
Args:
task (Task): The task to be added.
tasks (List[Task], optional): List of tasks to be executed.
"""
try:
if tasks:
for task in tasks:
self.tasks.append(task)
else:
self.tasks.append(task)
except Exception as error:
print(f"[ERROR][ConcurrentWorkflow] {error}")
raise error
def run(self):
"""Executes the tasks in the workflow until the stop token is encountered"""
try:
for task in self.tasks:
while True:
result = task.execute()
if self.stop_token in result:
break
except Exception as error:
print(f"[ERROR][RecursiveWorkflow] {error}")
raise error
```
In summary, the `RecursiveWorkflow` class is designed to automate tasks by adding and executing these tasks recursively until a stopping condition is reached. This can be achieved by utilizing the `add` and `run` methods provided. A general format for adding and utilizing the `RecursiveWorkflow` class has been provided under the "Examples" section. If you require any further information, view other sections, like Args and Source Code for specifics on using the class effectively. In summary, the `RecursiveWorkflow` class is designed to automate tasks by adding and executing these tasks recursively until a stopping condition is reached. This can be achieved by utilizing the `add` and `run` methods provided. A general format for adding and utilizing the `RecursiveWorkflow` class has been provided under the "Examples" section. If you require any further information, view other sections, like Args and Source Code for specifics on using the class effectively.

@ -11,8 +11,8 @@
from swarms.structs import Task, Agent from swarms.structs import Task, Agent
from swarms.models import OpenAIChat from swarms.models import OpenAIChat
agent = Agent(llm=OpenAIChat(openai_api_key=""), max_loops=1, dashboard=False) agent = Agent(llm=OpenAIChat(openai_api_key=""), max_loops=1, dashboard=False)
task = Task(description="What's the weather in miami", agent=agent) task = Task(agent=agent)
task.execute() task.execute("What's the weather in miami")
print(task.result) print(task.result)
# Example 2: Adding a dependency and setting priority # Example 2: Adding a dependency and setting priority

@ -1,10 +1,14 @@
from swarms import OpenAITTS from swarms import OpenAITTS
import os
from dotenv import load_dotenv
load_dotenv()
tts = OpenAITTS( tts = OpenAITTS(
model_name="tts-1-1106", model_name="tts-1-1106",
voice="onyx", voice="onyx",
openai_api_key="YOUR_API_KEY", openai_api_key=os.getenv("OPENAI_API_KEY")
) )
out = tts.run_and_save("pliny is a girl and a chicken") out = tts.run_and_save("Dammmmmm those tacos were good")
print(out) print(out)

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.6.8" version = "3.7.3"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]
@ -31,7 +31,7 @@ asyncio = "3.4.3"
einops = "0.7.0" einops = "0.7.0"
google-generativeai = "0.3.1" google-generativeai = "0.3.1"
langchain-experimental = "0.0.10" langchain-experimental = "0.0.10"
playwright = "1.34.0" tensorflow = "*"
weaviate-client = "3.25.3" weaviate-client = "3.25.3"
opencv-python-headless = "4.8.1.78" opencv-python-headless = "4.8.1.78"
faiss-cpu = "1.7.4" faiss-cpu = "1.7.4"
@ -44,14 +44,12 @@ PyPDF2 = "3.0.1"
accelerate = "*" accelerate = "*"
sentencepiece = "0.1.98" sentencepiece = "0.1.98"
wget = "3.2" wget = "3.2"
tensorflow = "2.14.0"
httpx = "0.24.1" httpx = "0.24.1"
tiktoken = "0.4.0" tiktoken = "0.4.0"
safetensors = "0.3.3" safetensors = "0.3.3"
attrs = "22.2.0" attrs = "22.2.0"
ggl = "1.1.0" ggl = "1.1.0"
ratelimit = "2.2.1" ratelimit = "2.2.1"
beautifulsoup4 = "4.11.2"
cohere = "4.24" cohere = "4.24"
huggingface-hub = "*" huggingface-hub = "*"
pydantic = "1.10.12" pydantic = "1.10.12"
@ -74,6 +72,7 @@ peft = "*"
psutil = "*" psutil = "*"
ultralytics = "*" ultralytics = "*"
timm = "*" timm = "*"
supervision = "*"

@ -22,7 +22,7 @@ requests_mock
PyPDF2==3.0.1 PyPDF2==3.0.1
accelerate==0.22.0 accelerate==0.22.0
chromadb==0.4.14 chromadb==0.4.14
tensorflow==2.14.0 tensorflow
optimum optimum
tiktoken==0.4.0 tiktoken==0.4.0
tabulate==0.9.0 tabulate==0.9.0
@ -60,4 +60,5 @@ mkdocs-glightbox
pre-commit==3.2.2 pre-commit==3.2.2
peft peft
psutil psutil
ultralytics ultralytics
supervision

@ -41,7 +41,9 @@ from swarms.models.gemini import Gemini # noqa: E402
from swarms.models.gigabind import Gigabind # noqa: E402 from swarms.models.gigabind import Gigabind # noqa: E402
from swarms.models.zeroscope import ZeroscopeTTV # noqa: E402 from swarms.models.zeroscope import ZeroscopeTTV # noqa: E402
from swarms.models.timm import TimmModel # noqa: E402 from swarms.models.timm import TimmModel # noqa: E402
from swarms.models.ultralytics_model import UltralyticsModel # noqa: E402 from swarms.models.ultralytics_model import (
UltralyticsModel,
) # noqa: E402
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3
@ -51,6 +53,9 @@ from swarms.models.ultralytics_model import UltralyticsModel # noqa: E402
# from swarms.models.cog_agent import CogAgent # noqa: E402 # from swarms.models.cog_agent import CogAgent # noqa: E402
################# Tokenizers
############## Types ############## Types
from swarms.models.types import ( from swarms.models.types import (
TextModality, TextModality,

@ -0,0 +1,90 @@
import supervision as sv
from ultraanalytics import YOLO
from tqdm import tqdm
from swarms.models.base_llm import AbstractLLM
class Odin(AbstractLLM):
"""
Odin class represents an object detection and tracking model.
Args:
source_weights_path (str): Path to the weights file for the object detection model.
source_video_path (str): Path to the source video file.
target_video_path (str): Path to save the output video file.
confidence_threshold (float): Confidence threshold for object detection.
iou_threshold (float): Intersection over Union (IoU) threshold for object detection.
Attributes:
source_weights_path (str): Path to the weights file for the object detection model.
source_video_path (str): Path to the source video file.
target_video_path (str): Path to save the output video file.
confidence_threshold (float): Confidence threshold for object detection.
iou_threshold (float): Intersection over Union (IoU) threshold for object detection.
"""
def __init__(
self,
source_weights_path: str = None,
target_video_path: str = None,
confidence_threshold: float = 0.3,
iou_threshold: float = 0.7,
):
super(Odin, self).__init__()
self.source_weights_path = source_weights_path
self.target_video_path = target_video_path
self.confidence_threshold = confidence_threshold
self.iou_threshold = iou_threshold
def run(self, video_path: str, *args, **kwargs):
"""
Runs the object detection and tracking algorithm on the specified video.
Args:
video_path (str): The path to the input video file.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
bool: True if the video was processed successfully, False otherwise.
"""
model = YOLO(self.source_weights_path)
tracker = sv.ByteTrack()
box_annotator = sv.BoxAnnotator()
frame_generator = sv.get_video_frames_generator(
source_path=self.source_video_path
)
video_info = sv.VideoInfo.from_video_path(
video_path=video_path
)
with sv.VideoSink(
target_path=self.target_video_path, video_info=video_info
) as sink:
for frame in tqdm(
frame_generator, total=video_info.total_frames
):
results = model(
frame,
verbose=True,
conf=self.confidence_threshold,
iou=self.iou_threshold,
)[0]
detections = sv.Detections.from_ultranalytics(results)
detections = tracker.update_with_detections(
detections
)
labels = [
f"#{tracker_id} {model.model.names[class_id]}"
for _, _, _, class_id, tracker_id in detections
]
annotated_frame = box_annotator.annotate(
scene=frame.copy(),
detections=detections,
labels=labels,
)
result = sink.write_frame(frame=annotated_frame)
return result

@ -0,0 +1,422 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
from collections import deque
from typing import List, Optional, Sequence, Union
import torch
from swarms.utils.get_logger import get_logger
class SentencePieceTokenizer:
"""Tokenizer of sentencepiece.
Args:
model_file (str): the path of the tokenizer model
"""
def __init__(self, model_file: str):
from sentencepiece import SentencePieceProcessor
self.model = SentencePieceProcessor(model_file=model_file)
self._prefix_space_tokens = None
# for stop words
self._maybe_decode_bytes: bool = None
# TODO maybe lack a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.logger = get_logger("lmdeploy")
@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size()
@property
def bos_token_id(self):
"""begine of the sentence token id."""
return self.model.bos_id()
@property
def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_id()
@property
def prefix_space_tokens(self):
"""tokens without prefix space."""
if self._prefix_space_tokens is None:
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
self._prefix_space_tokens = {
i
for i, tok in enumerate(vocab)
if tok.startswith("")
}
return self._prefix_space_tokens
def _maybe_add_prefix_space(self, tokens, decoded):
"""maybe add prefix space for incremental decoding."""
if (
len(tokens)
and not decoded.startswith(" ")
and tokens[0] in self.prefix_space_tokens
):
return " " + decoded
else:
return decoded
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if token == " ": # ' ' is special
token = ""
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
indexes = [i for i, voc in enumerate(vocab) if token in voc]
if len(indexes) > self.max_indexes_num:
indexes = self.encode(token, add_bos=False)[-1:]
self.logger.warning(
f"There are too many(>{self.max_indexes_num})"
f" possible indexes may decoding {token}, we will use"
f" {indexes} only"
)
self._indexes_tokens_deque.append((token, indexes))
return indexes
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
return self.model.Encode(s, add_bos=add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
if isinstance(t, torch.Tensor):
t = t.tolist()
t = t[offset:]
out_string = self.model.Decode(t)
if offset:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
Args:
s (str): prompts
Returns:
list[int]: token ids
"""
import addict
add_bos = False
add_eos = False
input_ids = self.model.Encode(
s, add_bos=add_bos, add_eos=add_eos
)
return addict.Addict(input_ids=input_ids)
class HuggingFaceTokenizer:
"""Tokenizer of sentencepiece.
Args:
model_dir (str): the directory of the tokenizer model
"""
def __init__(self, model_dir: str):
from transformers import AutoTokenizer
model_file = osp.join(model_dir, "tokenizer.model")
backend_tokenizer_file = osp.join(model_dir, "tokenizer.json")
model_file_exists = osp.exists(model_file)
self.logger = get_logger("lmdeploy")
if (
not osp.exists(backend_tokenizer_file)
and model_file_exists
):
self.logger.warning(
"Can not find tokenizer.json. "
"It may take long time to initialize the tokenizer."
)
self.model = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True
)
self._prefix_space_tokens = None
# save tokenizer.json to reuse
if (
not osp.exists(backend_tokenizer_file)
and model_file_exists
):
if hasattr(self.model, "backend_tokenizer"):
if os.access(model_dir, os.W_OK):
self.model.backend_tokenizer.save(
backend_tokenizer_file
)
if self.model.eos_token_id is None:
generation_config_file = osp.join(
model_dir, "generation_config.json"
)
if osp.exists(generation_config_file):
with open(generation_config_file, "r") as f:
cfg = json.load(f)
self.model.eos_token_id = cfg["eos_token_id"]
elif hasattr(self.model, "eod_id"): # Qwen remote
self.model.eos_token_id = self.model.eod_id
# for stop words
self._maybe_decode_bytes: bool = None
# TODO maybe lack a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.token2id = {}
@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size
@property
def bos_token_id(self):
"""begine of the sentence token id."""
return self.model.bos_token_id
@property
def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_token_id
@property
def prefix_space_tokens(self):
"""tokens without prefix space."""
if self._prefix_space_tokens is None:
vocab = self.model.convert_ids_to_tokens(
list(range(self.vocab_size))
)
self._prefix_space_tokens = {
i
for i, tok in enumerate(vocab)
if tok.startswith(
"" if isinstance(tok, str) else b" "
)
}
return self._prefix_space_tokens
def _maybe_add_prefix_space(
self, tokens: List[int], decoded: str
):
"""maybe add prefix space for incremental decoding."""
if (
len(tokens)
and not decoded.startswith(" ")
and tokens[0] in self.prefix_space_tokens
):
return " " + decoded
else:
return decoded
@property
def maybe_decode_bytes(self):
"""Check if self.model.convert_ids_to_tokens return not a str value."""
if self._maybe_decode_bytes is None:
self._maybe_decode_bytes = False
vocab = self.model.convert_ids_to_tokens(
list(range(self.vocab_size))
)
for tok in vocab:
if not isinstance(tok, str):
self._maybe_decode_bytes = True
break
return self._maybe_decode_bytes
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if self.token2id == {}:
# decode is slower than convert_ids_to_tokens
if self.maybe_decode_bytes:
self.token2id = {
self.model.decode(i): i
for i in range(self.vocab_size)
}
else:
self.token2id = {
self.model.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
if token == " ": # ' ' is special
token = ""
indexes = [
i
for _token, i in self.token2id.items()
if token in _token
]
if len(indexes) > self.max_indexes_num:
indexes = self.encode(token, add_bos=False)[-1:]
self.logger.warning(
f"There are too many(>{self.max_indexes_num})"
f" possible indexes may decoding {token}, we will use"
f" {indexes} only"
)
self._indexes_tokens_deque.append((token, indexes))
return indexes
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
encoded = self.model.encode(s, **kwargs)
if not add_bos:
# in the middle of a session
if len(encoded) and encoded[0] == self.bos_token_id:
encoded = encoded[1:]
return encoded
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
skip_special_tokens = True
t = t[offset:]
out_string = self.model.decode(
t, skip_special_tokens=skip_special_tokens
)
if offset:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
Args:
s (str): prompts
Returns:
list[int]: token ids
"""
add_special_tokens = False
return self.model(s, add_special_tokens=add_special_tokens)
class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Args:
model_file (str): the path of the tokenizer model
"""
def __init__(self, model_file: str):
if model_file.endswith(".model"):
model_folder = osp.split(model_file)[0]
else:
model_folder = model_file
model_file = osp.join(model_folder, "tokenizer.model")
tokenizer_config_file = osp.join(
model_folder, "tokenizer_config.json"
)
model_file_exists = osp.exists(model_file)
config_exists = osp.exists(tokenizer_config_file)
use_hf_model = config_exists or not model_file_exists
self.logger = get_logger("lmdeploy")
if not use_hf_model:
self.model = SentencePieceTokenizer(model_file)
else:
self.model = HuggingFaceTokenizer(model_folder)
@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size
@property
def bos_token_id(self):
"""begine of the sentence token id."""
return self.model.bos_token_id
@property
def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_token_id
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
return self.model.encode(s, add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
return self.model.decode(t, offset)
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
Args:
s (str): prompts
Returns:
list[int]: token ids
"""
return self.model(s)
def indexes_containing_token(self, token):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
encoded = self.encode(token, add_bos=False)
if len(encoded) > 1:
self.logger.warning(
f"The token {token}, its length of indexes"
f" {encoded} is over than 1. Currently, it can not be"
" used as stop words"
)
return []
return self.model.indexes_containing_token(token)

@ -109,11 +109,11 @@ class Task:
except Exception as error: except Exception as error:
logger.error(f"[ERROR][Task] {error}") logger.error(f"[ERROR][Task] {error}")
def run(self): def run(self, task: str, *args, **kwargs):
self.execute() self.execute(task, *args, **kwargs)
def __call__(self): def __call__(self, task: str, *args, **kwargs):
self.execute() self.execute(task, *args, **kwargs)
def handle_scheduled_task(self): def handle_scheduled_task(self):
""" """
@ -206,3 +206,5 @@ class Task:
logger.error( logger.error(
f"[ERROR][Task][check_dependency_completion] {error}" f"[ERROR][Task][check_dependency_completion] {error}"
) )

@ -0,0 +1,130 @@
import logging
from typing import List, Optional
logger_initialized = {}
def get_logger(
name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = "w",
):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(
name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode,
)
return logger
except Exception:
pass
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if log_file is not None:
# Here, the default behaviour of the official logger is 'a'. Thus, we
# provide an interface to change the file mode to the default
# behaviour.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
def filter_suffix(
response: str, suffixes: Optional[List[str]] = None
) -> str:
"""Filter response with suffixes.
Args:
response (str): generated response by LLMs.
suffixes (str): a list of suffixes to be deleted.
Return:
str: a clean response.
"""
if suffixes is None:
return response
for item in suffixes:
if response.endswith(item):
response = response[: len(response) - len(item)]
return response
# TODO remove stop_word_offsets stuff and make it clean
def _stop_words(stop_words: List[str], tokenizer: object):
"""return list of stop-words to numpy.ndarray."""
import numpy as np
if stop_words is None:
return None
assert isinstance(stop_words, List) and all(
isinstance(elem, str) for elem in stop_words
), f"stop_words must be a list but got {type(stop_words)}"
stop_indexes = []
for stop_word in stop_words:
stop_indexes += tokenizer.indexes_containing_token(stop_word)
assert isinstance(stop_indexes, List) and all(
isinstance(elem, int) for elem in stop_indexes
), "invalid stop_words"
# each id in stop_indexes represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about fastertransformer's stop_indexes
stop_word_offsets = range(1, len(stop_indexes) + 1)
stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(
np.int32
)
return stop_words

@ -0,0 +1,43 @@
from unittest.mock import patch
from swarms.models import TimmModel
import torch
def test_timm_model_init():
with patch("swarms.models.timm.list_models") as mock_list_models:
model_name = "resnet18"
pretrained = True
in_chans = 3
timm_model = TimmModel(model_name, pretrained, in_chans)
mock_list_models.assert_called_once()
assert timm_model.model_name == model_name
assert timm_model.pretrained == pretrained
assert timm_model.in_chans == in_chans
assert timm_model.models == mock_list_models.return_value
def test_timm_model_call():
with patch(
"swarms.models.timm.create_model"
) as mock_create_model:
model_name = "resnet18"
pretrained = True
in_chans = 3
timm_model = TimmModel(model_name, pretrained, in_chans)
task = torch.rand(1, in_chans, 224, 224)
result = timm_model(task)
mock_create_model.assert_called_once_with(
model_name, pretrained=pretrained, in_chans=in_chans
)
assert result == mock_create_model.return_value(task)
def test_timm_model_list_models():
with patch("swarms.models.timm.list_models") as mock_list_models:
model_name = "resnet18"
pretrained = True
in_chans = 3
timm_model = TimmModel(model_name, pretrained, in_chans)
result = timm_model.list_models()
mock_list_models.assert_called_once()
assert result == mock_list_models.return_value

@ -0,0 +1,35 @@
from unittest.mock import patch
from swarms.models.ultralytics_model import Ultralytics
def test_ultralytics_init():
with patch("swarms.models.YOLO") as mock_yolo:
model_name = "yolov5s"
ultralytics = Ultralytics(model_name)
mock_yolo.assert_called_once_with(model_name)
assert ultralytics.model_name == model_name
assert ultralytics.model == mock_yolo.return_value
def test_ultralytics_call():
with patch("swarms.models.YOLO") as mock_yolo:
model_name = "yolov5s"
ultralytics = Ultralytics(model_name)
task = "detect"
args = (1, 2, 3)
kwargs = {"a": "A", "b": "B"}
result = ultralytics(task, *args, **kwargs)
mock_yolo.return_value.assert_called_once_with(
task, *args, **kwargs
)
assert result == mock_yolo.return_value.return_value
def test_ultralytics_list_models():
with patch("swarms.models.YOLO") as mock_yolo:
model_name = "yolov5s"
ultralytics = Ultralytics(model_name)
result = ultralytics.list_models()
mock_yolo.list_models.assert_called_once()
assert result == mock_yolo.list_models.return_value
Loading…
Cancel
Save