diff --git a/docs/swarms/structs/recursiveworkflow.md b/docs/swarms/structs/recursiveworkflow.md index 5459c7cf..40c31478 100644 --- a/docs/swarms/structs/recursiveworkflow.md +++ b/docs/swarms/structs/recursiveworkflow.md @@ -20,52 +20,4 @@ workflow.add(task) workflow.run() ``` -Returns: None - -#### Source Code: - -```python -class RecursiveWorkflow(BaseStructure): - def __init__(self, stop_token: str = ""): - """ - Args: - stop_token (str, optional): The token that indicates when to stop the workflow. Default is "". - The stop_token indicates the value at which the current workflow is finished. - """ - self.stop_token = stop_token - self.tasks = [] - - assert ( - self.stop_token is not None - ), "stop_token cannot be None" - - def add(self, task: Task, tasks: List[Task] = None): - """Adds a task to the workflow. - Args: - task (Task): The task to be added. - tasks (List[Task], optional): List of tasks to be executed. - """ - try: - if tasks: - for task in tasks: - self.tasks.append(task) - else: - self.tasks.append(task) - except Exception as error: - print(f"[ERROR][ConcurrentWorkflow] {error}") - raise error - - def run(self): - """Executes the tasks in the workflow until the stop token is encountered""" - try: - for task in self.tasks: - while True: - result = task.execute() - if self.stop_token in result: - break - except Exception as error: - print(f"[ERROR][RecursiveWorkflow] {error}") - raise error -``` - In summary, the `RecursiveWorkflow` class is designed to automate tasks by adding and executing these tasks recursively until a stopping condition is reached. This can be achieved by utilizing the `add` and `run` methods provided. A general format for adding and utilizing the `RecursiveWorkflow` class has been provided under the "Examples" section. If you require any further information, view other sections, like Args and Source Code for specifics on using the class effectively. diff --git a/docs/swarms/structs/task.md b/docs/swarms/structs/task.md index 4a4080c0..7e829b66 100644 --- a/docs/swarms/structs/task.md +++ b/docs/swarms/structs/task.md @@ -11,8 +11,8 @@ from swarms.structs import Task, Agent from swarms.models import OpenAIChat agent = Agent(llm=OpenAIChat(openai_api_key=""), max_loops=1, dashboard=False) -task = Task(description="What's the weather in miami", agent=agent) -task.execute() +task = Task(agent=agent) +task.execute("What's the weather in miami") print(task.result) # Example 2: Adding a dependency and setting priority diff --git a/playground/models/tts_speech.py b/playground/models/tts_speech.py index f8ce3470..ca9d14f7 100644 --- a/playground/models/tts_speech.py +++ b/playground/models/tts_speech.py @@ -1,10 +1,14 @@ from swarms import OpenAITTS +import os +from dotenv import load_dotenv + +load_dotenv() tts = OpenAITTS( model_name="tts-1-1106", voice="onyx", - openai_api_key="YOUR_API_KEY", + openai_api_key=os.getenv("OPENAI_API_KEY") ) -out = tts.run_and_save("pliny is a girl and a chicken") +out = tts.run_and_save("Dammmmmm those tacos were good") print(out) diff --git a/pyproject.toml b/pyproject.toml index 9a7135ef..79ac36fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "3.6.8" +version = "3.7.3" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -31,7 +31,7 @@ asyncio = "3.4.3" einops = "0.7.0" google-generativeai = "0.3.1" langchain-experimental = "0.0.10" -playwright = "1.34.0" +tensorflow = "*" weaviate-client = "3.25.3" opencv-python-headless = "4.8.1.78" faiss-cpu = "1.7.4" @@ -44,14 +44,12 @@ PyPDF2 = "3.0.1" accelerate = "*" sentencepiece = "0.1.98" wget = "3.2" -tensorflow = "2.14.0" httpx = "0.24.1" tiktoken = "0.4.0" safetensors = "0.3.3" attrs = "22.2.0" ggl = "1.1.0" ratelimit = "2.2.1" -beautifulsoup4 = "4.11.2" cohere = "4.24" huggingface-hub = "*" pydantic = "1.10.12" @@ -74,6 +72,7 @@ peft = "*" psutil = "*" ultralytics = "*" timm = "*" +supervision = "*" diff --git a/requirements.txt b/requirements.txt index 3c14d025..d7befb85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ requests_mock PyPDF2==3.0.1 accelerate==0.22.0 chromadb==0.4.14 -tensorflow==2.14.0 +tensorflow optimum tiktoken==0.4.0 tabulate==0.9.0 @@ -60,4 +60,5 @@ mkdocs-glightbox pre-commit==3.2.2 peft psutil -ultralytics \ No newline at end of file +ultralytics +supervision \ No newline at end of file diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 0b1992eb..364d1d7f 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -41,7 +41,9 @@ from swarms.models.gemini import Gemini # noqa: E402 from swarms.models.gigabind import Gigabind # noqa: E402 from swarms.models.zeroscope import ZeroscopeTTV # noqa: E402 from swarms.models.timm import TimmModel # noqa: E402 -from swarms.models.ultralytics_model import UltralyticsModel # noqa: E402 +from swarms.models.ultralytics_model import ( + UltralyticsModel, +) # noqa: E402 # from swarms.models.dalle3 import Dalle3 @@ -51,6 +53,9 @@ from swarms.models.ultralytics_model import UltralyticsModel # noqa: E402 # from swarms.models.cog_agent import CogAgent # noqa: E402 +################# Tokenizers + + ############## Types from swarms.models.types import ( TextModality, diff --git a/swarms/models/odin.py b/swarms/models/odin.py new file mode 100644 index 00000000..74e0c556 --- /dev/null +++ b/swarms/models/odin.py @@ -0,0 +1,90 @@ +import supervision as sv +from ultraanalytics import YOLO +from tqdm import tqdm +from swarms.models.base_llm import AbstractLLM + +class Odin(AbstractLLM): + """ + Odin class represents an object detection and tracking model. + + Args: + source_weights_path (str): Path to the weights file for the object detection model. + source_video_path (str): Path to the source video file. + target_video_path (str): Path to save the output video file. + confidence_threshold (float): Confidence threshold for object detection. + iou_threshold (float): Intersection over Union (IoU) threshold for object detection. + + Attributes: + source_weights_path (str): Path to the weights file for the object detection model. + source_video_path (str): Path to the source video file. + target_video_path (str): Path to save the output video file. + confidence_threshold (float): Confidence threshold for object detection. + iou_threshold (float): Intersection over Union (IoU) threshold for object detection. + """ + + def __init__( + self, + source_weights_path: str = None, + target_video_path: str = None, + confidence_threshold: float = 0.3, + iou_threshold: float = 0.7, + ): + super(Odin, self).__init__() + self.source_weights_path = source_weights_path + self.target_video_path = target_video_path + self.confidence_threshold = confidence_threshold + self.iou_threshold = iou_threshold + + def run(self, video_path: str, *args, **kwargs): + """ + Runs the object detection and tracking algorithm on the specified video. + + Args: + video_path (str): The path to the input video file. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + bool: True if the video was processed successfully, False otherwise. + """ + model = YOLO(self.source_weights_path) + + tracker = sv.ByteTrack() + box_annotator = sv.BoxAnnotator() + frame_generator = sv.get_video_frames_generator( + source_path=self.source_video_path + ) + video_info = sv.VideoInfo.from_video_path( + video_path=video_path + ) + + with sv.VideoSink( + target_path=self.target_video_path, video_info=video_info + ) as sink: + for frame in tqdm( + frame_generator, total=video_info.total_frames + ): + results = model( + frame, + verbose=True, + conf=self.confidence_threshold, + iou=self.iou_threshold, + )[0] + detections = sv.Detections.from_ultranalytics(results) + detections = tracker.update_with_detections( + detections + ) + + labels = [ + f"#{tracker_id} {model.model.names[class_id]}" + for _, _, _, class_id, tracker_id in detections + ] + + annotated_frame = box_annotator.annotate( + scene=frame.copy(), + detections=detections, + labels=labels, + ) + + result = sink.write_frame(frame=annotated_frame) + return result diff --git a/swarms/models/r_tokenizers.py b/swarms/models/r_tokenizers.py new file mode 100644 index 00000000..cf8253fc --- /dev/null +++ b/swarms/models/r_tokenizers.py @@ -0,0 +1,422 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp +from collections import deque +from typing import List, Optional, Sequence, Union + +import torch + +from swarms.utils.get_logger import get_logger + + +class SentencePieceTokenizer: + """Tokenizer of sentencepiece. + + Args: + model_file (str): the path of the tokenizer model + """ + + def __init__(self, model_file: str): + from sentencepiece import SentencePieceProcessor + + self.model = SentencePieceProcessor(model_file=model_file) + self._prefix_space_tokens = None + # for stop words + self._maybe_decode_bytes: bool = None + # TODO maybe lack a constant.py + self._indexes_tokens_deque = deque(maxlen=10) + self.max_indexes_num = 5 + self.logger = get_logger("lmdeploy") + + @property + def vocab_size(self): + """vocabulary size.""" + return self.model.vocab_size() + + @property + def bos_token_id(self): + """begine of the sentence token id.""" + return self.model.bos_id() + + @property + def eos_token_id(self): + """end of the sentence token id.""" + return self.model.eos_id() + + @property + def prefix_space_tokens(self): + """tokens without prefix space.""" + if self._prefix_space_tokens is None: + vocab = self.model.IdToPiece(list(range(self.vocab_size))) + self._prefix_space_tokens = { + i + for i, tok in enumerate(vocab) + if tok.startswith("▁") + } + return self._prefix_space_tokens + + def _maybe_add_prefix_space(self, tokens, decoded): + """maybe add prefix space for incremental decoding.""" + if ( + len(tokens) + and not decoded.startswith(" ") + and tokens[0] in self.prefix_space_tokens + ): + return " " + decoded + else: + return decoded + + def indexes_containing_token(self, token: str): + """Return all the possible indexes, whose decoding output may contain + the input token.""" + # traversing vocab is time consuming, can not be accelerated with + # multi threads (computation) or multi process (can't pickle tokenizer) + # so, we maintain latest 10 stop words and return directly if matched + for _token, _indexes in self._indexes_tokens_deque: + if token == _token: + return _indexes + if token == " ": # ' ' is special + token = "▁" + vocab = self.model.IdToPiece(list(range(self.vocab_size))) + indexes = [i for i, voc in enumerate(vocab) if token in voc] + if len(indexes) > self.max_indexes_num: + indexes = self.encode(token, add_bos=False)[-1:] + self.logger.warning( + f"There are too many(>{self.max_indexes_num})" + f" possible indexes may decoding {token}, we will use" + f" {indexes} only" + ) + self._indexes_tokens_deque.append((token, indexes)) + return indexes + + def encode(self, s: str, add_bos: bool = True, **kwargs): + """Tokenize a prompt. + + Args: + s (str): a prompt + Returns: + list[int]: token ids + """ + return self.model.Encode(s, add_bos=add_bos, **kwargs) + + def decode(self, t: Sequence[int], offset: Optional[int] = None): + """De-tokenize. + + Args: + t (List[int]): a list of token ids + offset (int): for incrementally decoding. Default to None, which + means not applied. + Returns: + str: text of decoding tokens + """ + if isinstance(t, torch.Tensor): + t = t.tolist() + t = t[offset:] + out_string = self.model.Decode(t) + if offset: + out_string = self._maybe_add_prefix_space(t, out_string) + return out_string + + def __call__(self, s: Union[str, Sequence[str]]): + """Tokenize prompts. + + Args: + s (str): prompts + Returns: + list[int]: token ids + """ + import addict + + add_bos = False + add_eos = False + + input_ids = self.model.Encode( + s, add_bos=add_bos, add_eos=add_eos + ) + return addict.Addict(input_ids=input_ids) + + +class HuggingFaceTokenizer: + """Tokenizer of sentencepiece. + + Args: + model_dir (str): the directory of the tokenizer model + """ + + def __init__(self, model_dir: str): + from transformers import AutoTokenizer + + model_file = osp.join(model_dir, "tokenizer.model") + backend_tokenizer_file = osp.join(model_dir, "tokenizer.json") + model_file_exists = osp.exists(model_file) + self.logger = get_logger("lmdeploy") + if ( + not osp.exists(backend_tokenizer_file) + and model_file_exists + ): + self.logger.warning( + "Can not find tokenizer.json. " + "It may take long time to initialize the tokenizer." + ) + self.model = AutoTokenizer.from_pretrained( + model_dir, trust_remote_code=True + ) + self._prefix_space_tokens = None + # save tokenizer.json to reuse + if ( + not osp.exists(backend_tokenizer_file) + and model_file_exists + ): + if hasattr(self.model, "backend_tokenizer"): + if os.access(model_dir, os.W_OK): + self.model.backend_tokenizer.save( + backend_tokenizer_file + ) + + if self.model.eos_token_id is None: + generation_config_file = osp.join( + model_dir, "generation_config.json" + ) + if osp.exists(generation_config_file): + with open(generation_config_file, "r") as f: + cfg = json.load(f) + self.model.eos_token_id = cfg["eos_token_id"] + elif hasattr(self.model, "eod_id"): # Qwen remote + self.model.eos_token_id = self.model.eod_id + + # for stop words + self._maybe_decode_bytes: bool = None + # TODO maybe lack a constant.py + self._indexes_tokens_deque = deque(maxlen=10) + self.max_indexes_num = 5 + self.token2id = {} + + @property + def vocab_size(self): + """vocabulary size.""" + return self.model.vocab_size + + @property + def bos_token_id(self): + """begine of the sentence token id.""" + return self.model.bos_token_id + + @property + def eos_token_id(self): + """end of the sentence token id.""" + return self.model.eos_token_id + + @property + def prefix_space_tokens(self): + """tokens without prefix space.""" + if self._prefix_space_tokens is None: + vocab = self.model.convert_ids_to_tokens( + list(range(self.vocab_size)) + ) + self._prefix_space_tokens = { + i + for i, tok in enumerate(vocab) + if tok.startswith( + "▁" if isinstance(tok, str) else b" " + ) + } + return self._prefix_space_tokens + + def _maybe_add_prefix_space( + self, tokens: List[int], decoded: str + ): + """maybe add prefix space for incremental decoding.""" + if ( + len(tokens) + and not decoded.startswith(" ") + and tokens[0] in self.prefix_space_tokens + ): + return " " + decoded + else: + return decoded + + @property + def maybe_decode_bytes(self): + """Check if self.model.convert_ids_to_tokens return not a str value.""" + if self._maybe_decode_bytes is None: + self._maybe_decode_bytes = False + vocab = self.model.convert_ids_to_tokens( + list(range(self.vocab_size)) + ) + for tok in vocab: + if not isinstance(tok, str): + self._maybe_decode_bytes = True + break + return self._maybe_decode_bytes + + def indexes_containing_token(self, token: str): + """Return all the possible indexes, whose decoding output may contain + the input token.""" + # traversing vocab is time consuming, can not be accelerated with + # multi threads (computation) or multi process (can't pickle tokenizer) + # so, we maintain latest 10 stop words and return directly if matched + for _token, _indexes in self._indexes_tokens_deque: + if token == _token: + return _indexes + + if self.token2id == {}: + # decode is slower than convert_ids_to_tokens + if self.maybe_decode_bytes: + self.token2id = { + self.model.decode(i): i + for i in range(self.vocab_size) + } + else: + self.token2id = { + self.model.convert_ids_to_tokens(i): i + for i in range(self.vocab_size) + } + if token == " ": # ' ' is special + token = "▁" + indexes = [ + i + for _token, i in self.token2id.items() + if token in _token + ] + if len(indexes) > self.max_indexes_num: + indexes = self.encode(token, add_bos=False)[-1:] + self.logger.warning( + f"There are too many(>{self.max_indexes_num})" + f" possible indexes may decoding {token}, we will use" + f" {indexes} only" + ) + self._indexes_tokens_deque.append((token, indexes)) + return indexes + + def encode(self, s: str, add_bos: bool = True, **kwargs): + """Tokenize a prompt. + + Args: + s (str): a prompt + Returns: + list[int]: token ids + """ + encoded = self.model.encode(s, **kwargs) + if not add_bos: + # in the middle of a session + if len(encoded) and encoded[0] == self.bos_token_id: + encoded = encoded[1:] + return encoded + + def decode(self, t: Sequence[int], offset: Optional[int] = None): + """De-tokenize. + + Args: + t (List[int]): a list of token ids + offset (int): for incrementally decoding. Default to None, which + means not applied. + Returns: + str: text of decoding tokens + """ + skip_special_tokens = True + t = t[offset:] + out_string = self.model.decode( + t, skip_special_tokens=skip_special_tokens + ) + if offset: + out_string = self._maybe_add_prefix_space(t, out_string) + return out_string + + def __call__(self, s: Union[str, Sequence[str]]): + """Tokenize prompts. + + Args: + s (str): prompts + Returns: + list[int]: token ids + """ + add_special_tokens = False + return self.model(s, add_special_tokens=add_special_tokens) + + +class Tokenizer: + """Tokenize prompts or de-tokenize tokens into texts. + + Args: + model_file (str): the path of the tokenizer model + """ + + def __init__(self, model_file: str): + if model_file.endswith(".model"): + model_folder = osp.split(model_file)[0] + else: + model_folder = model_file + model_file = osp.join(model_folder, "tokenizer.model") + tokenizer_config_file = osp.join( + model_folder, "tokenizer_config.json" + ) + + model_file_exists = osp.exists(model_file) + config_exists = osp.exists(tokenizer_config_file) + use_hf_model = config_exists or not model_file_exists + self.logger = get_logger("lmdeploy") + if not use_hf_model: + self.model = SentencePieceTokenizer(model_file) + else: + self.model = HuggingFaceTokenizer(model_folder) + + @property + def vocab_size(self): + """vocabulary size.""" + return self.model.vocab_size + + @property + def bos_token_id(self): + """begine of the sentence token id.""" + return self.model.bos_token_id + + @property + def eos_token_id(self): + """end of the sentence token id.""" + return self.model.eos_token_id + + def encode(self, s: str, add_bos: bool = True, **kwargs): + """Tokenize a prompt. + + Args: + s (str): a prompt + Returns: + list[int]: token ids + """ + return self.model.encode(s, add_bos, **kwargs) + + def decode(self, t: Sequence[int], offset: Optional[int] = None): + """De-tokenize. + + Args: + t (List[int]): a list of token ids + offset (int): for incrementally decoding. Default to None, which + means not applied. + Returns: + str: text of decoding tokens + """ + return self.model.decode(t, offset) + + def __call__(self, s: Union[str, Sequence[str]]): + """Tokenize prompts. + + Args: + s (str): prompts + Returns: + list[int]: token ids + """ + return self.model(s) + + def indexes_containing_token(self, token): + """Return all the possible indexes, whose decoding output may contain + the input token.""" + encoded = self.encode(token, add_bos=False) + if len(encoded) > 1: + self.logger.warning( + f"The token {token}, its length of indexes" + f" {encoded} is over than 1. Currently, it can not be" + " used as stop words" + ) + return [] + return self.model.indexes_containing_token(token) diff --git a/swarms/structs/task.py b/swarms/structs/task.py index 699b7313..f2bf1bfc 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -109,11 +109,11 @@ class Task: except Exception as error: logger.error(f"[ERROR][Task] {error}") - def run(self): - self.execute() + def run(self, task: str, *args, **kwargs): + self.execute(task, *args, **kwargs) - def __call__(self): - self.execute() + def __call__(self, task: str, *args, **kwargs): + self.execute(task, *args, **kwargs) def handle_scheduled_task(self): """ @@ -206,3 +206,5 @@ class Task: logger.error( f"[ERROR][Task][check_dependency_completion] {error}" ) + + diff --git a/swarms/utils/get_logger.py b/swarms/utils/get_logger.py new file mode 100644 index 00000000..54fc8056 --- /dev/null +++ b/swarms/utils/get_logger.py @@ -0,0 +1,130 @@ +import logging +from typing import List, Optional + +logger_initialized = {} + + +def get_logger( + name: str, + log_file: Optional[str] = None, + log_level: int = logging.INFO, + file_mode: str = "w", +): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified, a FileHandler will also be added. + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + Returns: + logging.Logger: The expected logger. + """ + # use logger in mmengine if exists. + try: + from mmengine.logging import MMLogger + + if MMLogger.check_instance_created(name): + logger = MMLogger.get_instance(name) + else: + logger = MMLogger.get_instance( + name, + logger_name=name, + log_file=log_file, + log_level=log_level, + file_mode=file_mode, + ) + return logger + + except Exception: + pass + + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + logger.setLevel(log_level) + logger_initialized[name] = True + + return logger + + +def filter_suffix( + response: str, suffixes: Optional[List[str]] = None +) -> str: + """Filter response with suffixes. + + Args: + response (str): generated response by LLMs. + suffixes (str): a list of suffixes to be deleted. + + Return: + str: a clean response. + """ + if suffixes is None: + return response + for item in suffixes: + if response.endswith(item): + response = response[: len(response) - len(item)] + return response + + +# TODO remove stop_word_offsets stuff and make it clean +def _stop_words(stop_words: List[str], tokenizer: object): + """return list of stop-words to numpy.ndarray.""" + import numpy as np + + if stop_words is None: + return None + assert isinstance(stop_words, List) and all( + isinstance(elem, str) for elem in stop_words + ), f"stop_words must be a list but got {type(stop_words)}" + stop_indexes = [] + for stop_word in stop_words: + stop_indexes += tokenizer.indexes_containing_token(stop_word) + assert isinstance(stop_indexes, List) and all( + isinstance(elem, int) for elem in stop_indexes + ), "invalid stop_words" + # each id in stop_indexes represents a stop word + # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for + # detailed explanation about fastertransformer's stop_indexes + stop_word_offsets = range(1, len(stop_indexes) + 1) + stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype( + np.int32 + ) + return stop_words diff --git a/tests/models/test_timm.py b/tests/models/test_timm.py new file mode 100644 index 00000000..fae5f704 --- /dev/null +++ b/tests/models/test_timm.py @@ -0,0 +1,43 @@ +from unittest.mock import patch +from swarms.models import TimmModel +import torch + + +def test_timm_model_init(): + with patch("swarms.models.timm.list_models") as mock_list_models: + model_name = "resnet18" + pretrained = True + in_chans = 3 + timm_model = TimmModel(model_name, pretrained, in_chans) + mock_list_models.assert_called_once() + assert timm_model.model_name == model_name + assert timm_model.pretrained == pretrained + assert timm_model.in_chans == in_chans + assert timm_model.models == mock_list_models.return_value + + +def test_timm_model_call(): + with patch( + "swarms.models.timm.create_model" + ) as mock_create_model: + model_name = "resnet18" + pretrained = True + in_chans = 3 + timm_model = TimmModel(model_name, pretrained, in_chans) + task = torch.rand(1, in_chans, 224, 224) + result = timm_model(task) + mock_create_model.assert_called_once_with( + model_name, pretrained=pretrained, in_chans=in_chans + ) + assert result == mock_create_model.return_value(task) + + +def test_timm_model_list_models(): + with patch("swarms.models.timm.list_models") as mock_list_models: + model_name = "resnet18" + pretrained = True + in_chans = 3 + timm_model = TimmModel(model_name, pretrained, in_chans) + result = timm_model.list_models() + mock_list_models.assert_called_once() + assert result == mock_list_models.return_value diff --git a/tests/models/test_ultralytics.py b/tests/models/test_ultralytics.py new file mode 100644 index 00000000..8a1fed00 --- /dev/null +++ b/tests/models/test_ultralytics.py @@ -0,0 +1,35 @@ +from unittest.mock import patch +from swarms.models.ultralytics_model import Ultralytics + + +def test_ultralytics_init(): + with patch("swarms.models.YOLO") as mock_yolo: + model_name = "yolov5s" + ultralytics = Ultralytics(model_name) + mock_yolo.assert_called_once_with(model_name) + assert ultralytics.model_name == model_name + assert ultralytics.model == mock_yolo.return_value + + +def test_ultralytics_call(): + with patch("swarms.models.YOLO") as mock_yolo: + model_name = "yolov5s" + ultralytics = Ultralytics(model_name) + task = "detect" + args = (1, 2, 3) + kwargs = {"a": "A", "b": "B"} + result = ultralytics(task, *args, **kwargs) + mock_yolo.return_value.assert_called_once_with( + task, *args, **kwargs + ) + assert result == mock_yolo.return_value.return_value + + + +def test_ultralytics_list_models(): + with patch("swarms.models.YOLO") as mock_yolo: + model_name = "yolov5s" + ultralytics = Ultralytics(model_name) + result = ultralytics.list_models() + mock_yolo.list_models.assert_called_once() + assert result == mock_yolo.list_models.return_value \ No newline at end of file