parent
d70398806d
commit
e29ec9c943
@@ -1,10 +1,14 @@
 from swarms import OpenAITTS
+import os
+from dotenv import load_dotenv
+
+load_dotenv()

 tts = OpenAITTS(
     model_name="tts-1-1106",
     voice="onyx",
-    openai_api_key="YOUR_API_KEY",
+    openai_api_key=os.getenv("OPENAI_API_KEY")
 )

-out = tts.run_and_save("pliny is a girl and a chicken")
+out = tts.run_and_save("Dammmmmm those tacos were good")
 print(out)
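A quick sketch of the same setup with an explicit guard for a missing key; it only uses names that already appear in the diff above, and the error message is illustrative:

import os
from dotenv import load_dotenv
from swarms import OpenAITTS

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    # Fail early instead of passing None down to the OpenAI client.
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file")

tts = OpenAITTS(model_name="tts-1-1106", voice="onyx", openai_api_key=api_key)
print(tts.run_and_save("Dammmmmm those tacos were good"))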
@@ -0,0 +1,90 @@
import supervision as sv
from ultralytics import YOLO
from tqdm import tqdm
from swarms.models.base_llm import AbstractLLM


class Odin(AbstractLLM):
    """
    Odin represents an object detection and tracking model.

    Args:
        source_weights_path (str): Path to the weights file for the object detection model.
        target_video_path (str): Path to save the output video file.
        confidence_threshold (float): Confidence threshold for object detection.
        iou_threshold (float): Intersection over Union (IoU) threshold for object detection.

    Attributes:
        source_weights_path (str): Path to the weights file for the object detection model.
        target_video_path (str): Path to save the output video file.
        confidence_threshold (float): Confidence threshold for object detection.
        iou_threshold (float): Intersection over Union (IoU) threshold for object detection.
    """

    def __init__(
        self,
        source_weights_path: str = None,
        target_video_path: str = None,
        confidence_threshold: float = 0.3,
        iou_threshold: float = 0.7,
    ):
        super(Odin, self).__init__()
        self.source_weights_path = source_weights_path
        self.target_video_path = target_video_path
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold

    def run(self, video_path: str, *args, **kwargs):
        """
        Runs the object detection and tracking algorithm on the specified video.

        Args:
            video_path (str): The path to the input video file.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The value returned by the final sink.write_frame call.
        """
        model = YOLO(self.source_weights_path)

        tracker = sv.ByteTrack()
        box_annotator = sv.BoxAnnotator()
        # Read frames from the input video and collect its metadata (fps, size, frame count).
        frame_generator = sv.get_video_frames_generator(
            source_path=video_path
        )
        video_info = sv.VideoInfo.from_video_path(
            video_path=video_path
        )

        with sv.VideoSink(
            target_path=self.target_video_path, video_info=video_info
        ) as sink:
            for frame in tqdm(
                frame_generator, total=video_info.total_frames
            ):
                results = model(
                    frame,
                    verbose=True,
                    conf=self.confidence_threshold,
                    iou=self.iou_threshold,
                )[0]
                detections = sv.Detections.from_ultralytics(results)
                detections = tracker.update_with_detections(
                    detections
                )

                # One label per tracked detection: "#<track id> <class name>".
                labels = [
                    f"#{tracker_id} {model.model.names[class_id]}"
                    for _, _, _, class_id, tracker_id in detections
                ]

                annotated_frame = box_annotator.annotate(
                    scene=frame.copy(),
                    detections=detections,
                    labels=labels,
                )

                result = sink.write_frame(frame=annotated_frame)
        return result
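A hedged usage sketch for the Odin class above; the weight and video paths are placeholders, and the call pattern follows the __init__ and run signatures shown in this file:

odin = Odin(
    source_weights_path="yolov8n.pt",      # placeholder weights file
    target_video_path="annotated.mp4",     # where the annotated video is written
    confidence_threshold=0.3,
    iou_threshold=0.7,
)
odin.run(video_path="input.mp4")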
@@ -0,0 +1,422 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
from collections import deque
from typing import List, Optional, Sequence, Union

import torch

from swarms.utils.get_logger import get_logger


class SentencePieceTokenizer:
    """Tokenizer of sentencepiece.

    Args:
        model_file (str): the path of the tokenizer model
    """

    def __init__(self, model_file: str):
        from sentencepiece import SentencePieceProcessor

        self.model = SentencePieceProcessor(model_file=model_file)
        self._prefix_space_tokens = None
        # for stop words
        self._maybe_decode_bytes: bool = None
        # TODO maybe lack a constant.py
        self._indexes_tokens_deque = deque(maxlen=10)
        self.max_indexes_num = 5
        self.logger = get_logger("lmdeploy")

    @property
    def vocab_size(self):
        """Vocabulary size."""
        return self.model.vocab_size()

    @property
    def bos_token_id(self):
        """Beginning-of-sentence token id."""
        return self.model.bos_id()

    @property
    def eos_token_id(self):
        """End-of-sentence token id."""
        return self.model.eos_id()

    @property
    def prefix_space_tokens(self):
        """Tokens without prefix space."""
        if self._prefix_space_tokens is None:
            vocab = self.model.IdToPiece(list(range(self.vocab_size)))
            self._prefix_space_tokens = {
                i
                for i, tok in enumerate(vocab)
                if tok.startswith("▁")
            }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(self, tokens, decoded):
        """Maybe add a prefix space for incremental decoding."""
        if (
            len(tokens)
            and not decoded.startswith(" ")
            and tokens[0] in self.prefix_space_tokens
        ):
            return " " + decoded
        else:
            return decoded

    def indexes_containing_token(self, token: str):
        """Return all the possible indexes whose decoding output may contain
        the input token."""
        # Traversing the vocab is time consuming and cannot be accelerated with
        # multiple threads (computation) or processes (the tokenizer can't be pickled),
        # so we keep the latest 10 stop words and return directly on a match.
        for _token, _indexes in self._indexes_tokens_deque:
            if token == _token:
                return _indexes
        if token == " ":  # ' ' is special
            token = "▁"
        vocab = self.model.IdToPiece(list(range(self.vocab_size)))
        indexes = [i for i, voc in enumerate(vocab) if token in voc]
        if len(indexes) > self.max_indexes_num:
            indexes = self.encode(token, add_bos=False)[-1:]
            self.logger.warning(
                f"There are too many (>{self.max_indexes_num})"
                f" possible indexes that may decode {token};"
                f" only {indexes} will be used"
            )
        self._indexes_tokens_deque.append((token, indexes))
        return indexes

    def encode(self, s: str, add_bos: bool = True, **kwargs):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        return self.model.Encode(s, add_bos=add_bos, **kwargs)

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text decoded from the token ids
        """
        if isinstance(t, torch.Tensor):
            t = t.tolist()
        t = t[offset:]
        out_string = self.model.Decode(t)
        if offset:
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str): prompts
        Returns:
            list[int]: token ids
        """
        import addict

        add_bos = False
        add_eos = False

        input_ids = self.model.Encode(
            s, add_bos=add_bos, add_eos=add_eos
        )
        return addict.Addict(input_ids=input_ids)

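# Editor's note: a minimal sketch of incremental decoding with the class above.
# "tokenizer.model" is a placeholder path. decode(t, offset=n) drops the first n
# ids, decodes only the new tail, and _maybe_add_prefix_space restores a leading
# space when the first remaining token starts a new word:
#
#   tok = SentencePieceTokenizer("tokenizer.model")
#   ids = tok.encode("Hello world", add_bos=False)
#   tail = tok.decode(ids, offset=len(ids) - 1)  # e.g. " world"
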
class HuggingFaceTokenizer:
    """Tokenizer backed by huggingface transformers.

    Args:
        model_dir (str): the directory of the tokenizer model
    """

    def __init__(self, model_dir: str):
        from transformers import AutoTokenizer

        model_file = osp.join(model_dir, "tokenizer.model")
        backend_tokenizer_file = osp.join(model_dir, "tokenizer.json")
        model_file_exists = osp.exists(model_file)
        self.logger = get_logger("lmdeploy")
        if (
            not osp.exists(backend_tokenizer_file)
            and model_file_exists
        ):
            self.logger.warning(
                "Can not find tokenizer.json. "
                "It may take a long time to initialize the tokenizer."
            )
        self.model = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self._prefix_space_tokens = None
        # save tokenizer.json to reuse
        if (
            not osp.exists(backend_tokenizer_file)
            and model_file_exists
        ):
            if hasattr(self.model, "backend_tokenizer"):
                if os.access(model_dir, os.W_OK):
                    self.model.backend_tokenizer.save(
                        backend_tokenizer_file
                    )

        if self.model.eos_token_id is None:
            generation_config_file = osp.join(
                model_dir, "generation_config.json"
            )
            if osp.exists(generation_config_file):
                with open(generation_config_file, "r") as f:
                    cfg = json.load(f)
                    self.model.eos_token_id = cfg["eos_token_id"]
            elif hasattr(self.model, "eod_id"):  # Qwen remote
                self.model.eos_token_id = self.model.eod_id

        # for stop words
        self._maybe_decode_bytes: bool = None
        # TODO maybe lack a constant.py
        self._indexes_tokens_deque = deque(maxlen=10)
        self.max_indexes_num = 5
        self.token2id = {}

    @property
    def vocab_size(self):
        """Vocabulary size."""
        return self.model.vocab_size

    @property
    def bos_token_id(self):
        """Beginning-of-sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """End-of-sentence token id."""
        return self.model.eos_token_id

    @property
    def prefix_space_tokens(self):
        """Tokens without prefix space."""
        if self._prefix_space_tokens is None:
            vocab = self.model.convert_ids_to_tokens(
                list(range(self.vocab_size))
            )
            self._prefix_space_tokens = {
                i
                for i, tok in enumerate(vocab)
                if tok.startswith(
                    "▁" if isinstance(tok, str) else b" "
                )
            }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(
        self, tokens: List[int], decoded: str
    ):
        """Maybe add a prefix space for incremental decoding."""
        if (
            len(tokens)
            and not decoded.startswith(" ")
            and tokens[0] in self.prefix_space_tokens
        ):
            return " " + decoded
        else:
            return decoded

    @property
    def maybe_decode_bytes(self):
        """Check whether self.model.convert_ids_to_tokens returns non-str values."""
        if self._maybe_decode_bytes is None:
            self._maybe_decode_bytes = False
            vocab = self.model.convert_ids_to_tokens(
                list(range(self.vocab_size))
            )
            for tok in vocab:
                if not isinstance(tok, str):
                    self._maybe_decode_bytes = True
                    break
        return self._maybe_decode_bytes

    def indexes_containing_token(self, token: str):
        """Return all the possible indexes whose decoding output may contain
        the input token."""
        # Traversing the vocab is time consuming and cannot be accelerated with
        # multiple threads (computation) or processes (the tokenizer can't be pickled),
        # so we keep the latest 10 stop words and return directly on a match.
        for _token, _indexes in self._indexes_tokens_deque:
            if token == _token:
                return _indexes

        if self.token2id == {}:
            # decode is slower than convert_ids_to_tokens
            if self.maybe_decode_bytes:
                self.token2id = {
                    self.model.decode(i): i
                    for i in range(self.vocab_size)
                }
            else:
                self.token2id = {
                    self.model.convert_ids_to_tokens(i): i
                    for i in range(self.vocab_size)
                }
        if token == " ":  # ' ' is special
            token = "▁"
        indexes = [
            i
            for _token, i in self.token2id.items()
            if token in _token
        ]
        if len(indexes) > self.max_indexes_num:
            indexes = self.encode(token, add_bos=False)[-1:]
            self.logger.warning(
                f"There are too many (>{self.max_indexes_num})"
                f" possible indexes that may decode {token};"
                f" only {indexes} will be used"
            )
        self._indexes_tokens_deque.append((token, indexes))
        return indexes

    def encode(self, s: str, add_bos: bool = True, **kwargs):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        encoded = self.model.encode(s, **kwargs)
        if not add_bos:
            # in the middle of a session
            if len(encoded) and encoded[0] == self.bos_token_id:
                encoded = encoded[1:]
        return encoded

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text decoded from the token ids
        """
        skip_special_tokens = True
        t = t[offset:]
        out_string = self.model.decode(
            t, skip_special_tokens=skip_special_tokens
        )
        if offset:
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str): prompts
        Returns:
            list[int]: token ids
        """
        add_special_tokens = False
        return self.model(s, add_special_tokens=add_special_tokens)


class Tokenizer:
    """Tokenize prompts or de-tokenize tokens into texts.

    Args:
        model_file (str): the path of the tokenizer model
    """

    def __init__(self, model_file: str):
        if model_file.endswith(".model"):
            model_folder = osp.split(model_file)[0]
        else:
            model_folder = model_file
            model_file = osp.join(model_folder, "tokenizer.model")
        tokenizer_config_file = osp.join(
            model_folder, "tokenizer_config.json"
        )

        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
        use_hf_model = config_exists or not model_file_exists
        self.logger = get_logger("lmdeploy")
        if not use_hf_model:
            self.model = SentencePieceTokenizer(model_file)
        else:
            self.model = HuggingFaceTokenizer(model_folder)

    @property
    def vocab_size(self):
        """Vocabulary size."""
        return self.model.vocab_size

    @property
    def bos_token_id(self):
        """Beginning-of-sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """End-of-sentence token id."""
        return self.model.eos_token_id

    def encode(self, s: str, add_bos: bool = True, **kwargs):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        return self.model.encode(s, add_bos, **kwargs)

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text decoded from the token ids
        """
        return self.model.decode(t, offset)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str): prompts
        Returns:
            list[int]: token ids
        """
        return self.model(s)

    def indexes_containing_token(self, token):
        """Return all the possible indexes whose decoding output may contain
        the input token."""
        encoded = self.encode(token, add_bos=False)
        if len(encoded) > 1:
            self.logger.warning(
                f"The token {token} encodes to the indexes {encoded},"
                " which is more than one index. Currently it can not be"
                " used as a stop word"
            )
            return []
        return self.model.indexes_containing_token(token)
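The Tokenizer facade above picks its backend from what is on disk (a tokenizer_config.json or a missing tokenizer.model routes to HuggingFaceTokenizer, otherwise SentencePieceTokenizer). A short usage sketch with a placeholder model directory:

tokenizer = Tokenizer("/path/to/model_dir")
ids = tokenizer.encode("Say hi to the model", add_bos=True)
text = tokenizer.decode(ids)
stop_ids = tokenizer.indexes_containing_token("</s>")  # [] if it needs more than one token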
@@ -0,0 +1,130 @@
import logging
from typing import List, Optional

logger_initialized = {}


def get_logger(
    name: str,
    log_file: Optional[str] = None,
    log_level: int = logging.INFO,
    file_mode: str = "w",
):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified, a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.
    Returns:
        logging.Logger: The expected logger.
    """
    # use logger in mmengine if exists.
    try:
        from mmengine.logging import MMLogger

        if MMLogger.check_instance_created(name):
            logger = MMLogger.get_instance(name)
        else:
            logger = MMLogger.get_instance(
                name,
                logger_name=name,
                log_file=log_file,
                log_level=log_level,
                file_mode=file_mode,
            )
        return logger

    except Exception:
        pass

    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        # The stdlib FileHandler defaults to mode 'a'; file_mode lets callers
        # override that (this helper defaults to 'w').
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger_initialized[name] = True

    return logger

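# Editor's note: a usage sketch for get_logger. Repeated calls with the same
# name return the already-configured logger (tracked via logger_initialized),
# so handlers are not duplicated. The log file name is a placeholder:
#
#   logger = get_logger("lmdeploy", log_file="lmdeploy.log", log_level=logging.DEBUG)
#   logger.info("tokenizer initialized")
#   assert get_logger("lmdeploy") is logger
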
def filter_suffix(
    response: str, suffixes: Optional[List[str]] = None
) -> str:
    """Filter response with suffixes.

    Args:
        response (str): generated response by LLMs.
        suffixes (List[str]): a list of suffixes to be deleted.

    Returns:
        str: a clean response.
    """
    if suffixes is None:
        return response
    for item in suffixes:
        if response.endswith(item):
            response = response[: len(response) - len(item)]
    return response


# TODO remove stop_word_offsets stuff and make it clean
def _stop_words(stop_words: List[str], tokenizer: object):
    """Convert a list of stop-word strings to a numpy.ndarray of token indexes."""
    import numpy as np

    if stop_words is None:
        return None
    assert isinstance(stop_words, List) and all(
        isinstance(elem, str) for elem in stop_words
    ), f"stop_words must be a list but got {type(stop_words)}"
    stop_indexes = []
    for stop_word in stop_words:
        stop_indexes += tokenizer.indexes_containing_token(stop_word)
    assert isinstance(stop_indexes, List) and all(
        isinstance(elem, int) for elem in stop_indexes
    ), "invalid stop_words"
    # each id in stop_indexes represents a stop word
    # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
    # a detailed explanation of fastertransformer's stop_indexes
    stop_word_offsets = range(1, len(stop_indexes) + 1)
    stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(
        np.int32
    )
    return stop_words
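A sketch of how the two helpers above are typically combined with a Tokenizer instance (assuming the Tokenizer class from the previous file is importable); the model path and stop-word strings are examples only:

tokenizer = Tokenizer("/path/to/model_dir")
stop_array = _stop_words(["</s>"], tokenizer)   # shape (1, 2, n): token indexes row, 1-based offsets row
clean = filter_suffix("The answer is 42</s>", suffixes=["</s>"])  # -> "The answer is 42"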
@@ -0,0 +1,43 @@
from unittest.mock import patch
from swarms.models import TimmModel
import torch


def test_timm_model_init():
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        mock_list_models.assert_called_once()
        assert timm_model.model_name == model_name
        assert timm_model.pretrained == pretrained
        assert timm_model.in_chans == in_chans
        assert timm_model.models == mock_list_models.return_value


def test_timm_model_call():
    with patch(
        "swarms.models.timm.create_model"
    ) as mock_create_model:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        task = torch.rand(1, in_chans, 224, 224)
        result = timm_model(task)
        mock_create_model.assert_called_once_with(
            model_name, pretrained=pretrained, in_chans=in_chans
        )
        assert result == mock_create_model.return_value(task)


def test_timm_model_list_models():
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        result = timm_model.list_models()
        mock_list_models.assert_called_once()
        assert result == mock_list_models.return_value
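Outside the mocks, the interface exercised by these tests suggests direct usage like the following sketch (arguments are passed positionally exactly as in the tests; nothing beyond that is assumed about TimmModel):

model = TimmModel("resnet18", True, 3)   # model_name, pretrained, in_chans
x = torch.rand(1, 3, 224, 224)
out = model(x)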
@@ -0,0 +1,35 @@
from unittest.mock import patch
from swarms.models.ultralytics_model import Ultralytics


def test_ultralytics_init():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = Ultralytics(model_name)
        mock_yolo.assert_called_once_with(model_name)
        assert ultralytics.model_name == model_name
        assert ultralytics.model == mock_yolo.return_value


def test_ultralytics_call():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = Ultralytics(model_name)
        task = "detect"
        args = (1, 2, 3)
        kwargs = {"a": "A", "b": "B"}
        result = ultralytics(task, *args, **kwargs)
        mock_yolo.return_value.assert_called_once_with(
            task, *args, **kwargs
        )
        assert result == mock_yolo.return_value.return_value


def test_ultralytics_list_models():
    with patch("swarms.models.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = Ultralytics(model_name)
        result = ultralytics.list_models()
        mock_yolo.list_models.assert_called_once()
        assert result == mock_yolo.list_models.return_value
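And the equivalent non-mocked sketch for the Ultralytics wrapper, mirroring the call pattern in test_ultralytics_call; the weights file and input are placeholders:

ultralytics = Ultralytics("yolov8n.pt")
results = ultralytics("path/to/image.jpg")  # forwarded to the underlying YOLO model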