From 20273bb5db7218ad727ee17654b9455cf3a87722 Mon Sep 17 00:00:00 2001
From: Kye <kye@apacmediasolutions.com>
Date: Sat, 11 Nov 2023 10:06:36 -0500
Subject: [PATCH] tests, workflow fixes + torch verison

Former-commit-id: 06f469b21f89d6937bc2f3b0f17146fe883feb5a
---
 .github/workflows/docs.yml                 |   1 +
 pyproject.toml                             |   2 +-
 requirements.txt                           |   2 +-
 swarms/models/openai_chat.py               |   2 +-
 swarms/models/whisperx.py                  | 126 +++-
 swarms/tools/stt.py                        | 125 ----
 tests/models/kosmos2.py                    | 365 ++++++++++
 tests/models/whisperx.py                   | 206 ++++++
 tests/tools/base.py                        | 802 +++++++++++++++++++++
 tests/utils/subprocess_code_interpreter.py | 246 +++++++
 10 files changed, 1748 insertions(+), 129 deletions(-)
 delete mode 100644 swarms/tools/stt.py
 create mode 100644 tests/models/kosmos2.py
 create mode 100644 tests/models/whisperx.py
 create mode 100644 tests/tools/base.py
 create mode 100644 tests/utils/subprocess_code_interpreter.py

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 0f89cb4c..a5a31f4b 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -15,5 +15,6 @@ jobs:
         with:
           python-version: 3.x
       - run: pip install mkdocs-material
+      - run: pip install mkdocs-glightbox
       - run: pip install "mkdocstrings[python]"
       - run: mkdocs gh-deploy --force
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index f76f7177..dca7c789 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ asyncio = "*"
 nest_asyncio = "*"
 einops = "*"
 google-generativeai = "*"
-torch = "*"
+torch = "2.1.0"
 langchain-experimental = "*"
 playwright = "*"
 duckduckgo-search = "*"
diff --git a/requirements.txt b/requirements.txt
index 82e519af..e1148c30 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -70,4 +70,4 @@ rich
 
 mkdocs
 mkdocs-material
-mkdocs-glightbox
+mkdocs-glightbox
\ No newline at end of file
diff --git a/swarms/models/openai_chat.py b/swarms/models/openai_chat.py
index 6ca964a2..3933d8a7 100644
--- a/swarms/models/openai_chat.py
+++ b/swarms/models/openai_chat.py
@@ -214,7 +214,7 @@ class OpenAIChat(BaseChatModel):
     # Check for classes that derive from this class (as some of them
     # may assume openai_api_key is a str)
     # openai_api_key: Optional[str] = Field(default=None, alias="api_key")
-    openai_api_key = "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"
+    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
     openai_api_base: Optional[str] = Field(default=None, alias="base_url")
     """Base URL path for API requests, leave blank if not using a proxy or service 
diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py
index 1731daa1..102ae7d7 100644
--- a/swarms/models/whisperx.py
+++ b/swarms/models/whisperx.py
@@ -1 +1,125 @@
-"""An ultra fast speech to text model."""
+# speech to text tool
+
+import os
+import subprocess
+
+import whisperx
+from pydub import AudioSegment
+from pytube import YouTube
+
+
+class WhisperX:
+    def __init__(
+        self,
+        video_url,
+        audio_format="mp3",
+        device="cuda",
+        batch_size=16,
+        compute_type="float16",
+        hf_api_key=None,
+    ):
+        """
+        # Example usage
+        video_url = "url"
+        speech_to_text = SpeechToText(video_url)
+        transcription = speech_to_text.transcribe_youtube_video()
+        print(transcription)
+
+        """
+        self.video_url = video_url
+        self.audio_format = audio_format
+        self.device = device
+        self.batch_size = batch_size
+        self.compute_type = compute_type
+        self.hf_api_key = hf_api_key
+
+    def install(self):
+        subprocess.run(["pip", "install", "whisperx"])
+        subprocess.run(["pip", "install", "pytube"])
+        subprocess.run(["pip", "install", "pydub"])
+
+    def download_youtube_video(self):
+        audio_file = f"video.{self.audio_format}"
+
+        # Download video πŸ“₯
+        yt = YouTube(self.video_url)
+        yt_stream = yt.streams.filter(only_audio=True).first()
+        yt_stream.download(filename="video.mp4")
+
+        # Convert video to audio 🎧
+        video = AudioSegment.from_file("video.mp4", format="mp4")
+        video.export(audio_file, format=self.audio_format)
+        os.remove("video.mp4")
+
+        return audio_file
+
+    def transcribe_youtube_video(self):
+        audio_file = self.download_youtube_video()
+
+        device = "cuda"
+        batch_size = 16
+        compute_type = "float16"
+
+        # 1. Transcribe with original Whisper (batched) πŸ—£οΈ
+        model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+        audio = whisperx.load_audio(audio_file)
+        result = model.transcribe(audio, batch_size=batch_size)
+
+        # 2. Align Whisper output πŸ”
+        model_a, metadata = whisperx.load_align_model(
+            language_code=result["language"], device=device
+        )
+        result = whisperx.align(
+            result["segments"],
+            model_a,
+            metadata,
+            audio,
+            device,
+            return_char_alignments=False,
+        )
+
+        # 3. Assign speaker labels 🏷️
+        diarize_model = whisperx.DiarizationPipeline(
+            use_auth_token=self.hf_api_key, device=device
+        )
+        diarize_model(audio_file)
+
+        try:
+            segments = result["segments"]
+            transcription = " ".join(segment["text"] for segment in segments)
+            return transcription
+        except KeyError:
+            print("The key 'segments' is not found in the result.")
+
+    def transcribe(self, audio_file):
+        model = whisperx.load_model("large-v2", self.device, self.compute_type)
+        audio = whisperx.load_audio(audio_file)
+        result = model.transcribe(audio, batch_size=self.batch_size)
+
+        # 2. Align Whisper output πŸ”
+        model_a, metadata = whisperx.load_align_model(
+            language_code=result["language"], device=self.device
+        )
+
+        result = whisperx.align(
+            result["segments"],
+            model_a,
+            metadata,
+            audio,
+            self.device,
+            return_char_alignments=False,
+        )
+
+        # 3. Assign speaker labels 🏷️
+        diarize_model = whisperx.DiarizationPipeline(
+            use_auth_token=self.hf_api_key, device=self.device
+        )
+
+        diarize_model(audio_file)
+
+        try:
+            segments = result["segments"]
+            transcription = " ".join(segment["text"] for segment in segments)
+            return transcription
+        except KeyError:
+            print("The key 'segments' is not found in the result.")
diff --git a/swarms/tools/stt.py b/swarms/tools/stt.py
deleted file mode 100644
index cfe3e656..00000000
--- a/swarms/tools/stt.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# speech to text tool
-
-import os
-import subprocess
-
-import whisperx
-from pydub import AudioSegment
-from pytube import YouTube
-
-
-class SpeechToText:
-    def __init__(
-        self,
-        video_url,
-        audio_format="mp3",
-        device="cuda",
-        batch_size=16,
-        compute_type="float16",
-        hf_api_key=None,
-    ):
-        """
-        # Example usage
-        video_url = "url"
-        speech_to_text = SpeechToText(video_url)
-        transcription = speech_to_text.transcribe_youtube_video()
-        print(transcription)
-
-        """
-        self.video_url = video_url
-        self.audio_format = audio_format
-        self.device = device
-        self.batch_size = batch_size
-        self.compute_type = compute_type
-        self.hf_api_key = hf_api_key
-
-    def install(self):
-        subprocess.run(["pip", "install", "whisperx"])
-        subprocess.run(["pip", "install", "pytube"])
-        subprocess.run(["pip", "install", "pydub"])
-
-    def download_youtube_video(self):
-        audio_file = f"video.{self.audio_format}"
-
-        # Download video πŸ“₯
-        yt = YouTube(self.video_url)
-        yt_stream = yt.streams.filter(only_audio=True).first()
-        yt_stream.download(filename="video.mp4")
-
-        # Convert video to audio 🎧
-        video = AudioSegment.from_file("video.mp4", format="mp4")
-        video.export(audio_file, format=self.audio_format)
-        os.remove("video.mp4")
-
-        return audio_file
-
-    def transcribe_youtube_video(self):
-        audio_file = self.download_youtube_video()
-
-        device = "cuda"
-        batch_size = 16
-        compute_type = "float16"
-
-        # 1. Transcribe with original Whisper (batched) πŸ—£οΈ
-        model = whisperx.load_model("large-v2", device, compute_type=compute_type)
-        audio = whisperx.load_audio(audio_file)
-        result = model.transcribe(audio, batch_size=batch_size)
-
-        # 2. Align Whisper output πŸ”
-        model_a, metadata = whisperx.load_align_model(
-            language_code=result["language"], device=device
-        )
-        result = whisperx.align(
-            result["segments"],
-            model_a,
-            metadata,
-            audio,
-            device,
-            return_char_alignments=False,
-        )
-
-        # 3. Assign speaker labels 🏷️
-        diarize_model = whisperx.DiarizationPipeline(
-            use_auth_token=self.hf_api_key, device=device
-        )
-        diarize_model(audio_file)
-
-        try:
-            segments = result["segments"]
-            transcription = " ".join(segment["text"] for segment in segments)
-            return transcription
-        except KeyError:
-            print("The key 'segments' is not found in the result.")
-
-    def transcribe(self, audio_file):
-        model = whisperx.load_model("large-v2", self.device, self.compute_type)
-        audio = whisperx.load_audio(audio_file)
-        result = model.transcribe(audio, batch_size=self.batch_size)
-
-        # 2. Align Whisper output πŸ”
-        model_a, metadata = whisperx.load_align_model(
-            language_code=result["language"], device=self.device
-        )
-
-        result = whisperx.align(
-            result["segments"],
-            model_a,
-            metadata,
-            audio,
-            self.device,
-            return_char_alignments=False,
-        )
-
-        # 3. Assign speaker labels 🏷️
-        diarize_model = whisperx.DiarizationPipeline(
-            use_auth_token=self.hf_api_key, device=self.device
-        )
-
-        diarize_model(audio_file)
-
-        try:
-            segments = result["segments"]
-            transcription = " ".join(segment["text"] for segment in segments)
-            return transcription
-        except KeyError:
-            print("The key 'segments' is not found in the result.")
diff --git a/tests/models/kosmos2.py b/tests/models/kosmos2.py
new file mode 100644
index 00000000..2ff01092
--- /dev/null
+++ b/tests/models/kosmos2.py
@@ -0,0 +1,365 @@
+import pytest
+import os
+from PIL import Image
+from swarms.models.kosmos2 import Kosmos2, Detections
+
+
+# Fixture for a sample image
+@pytest.fixture
+def sample_image():
+    image = Image.new("RGB", (224, 224))
+    return image
+
+
+# Fixture for initializing Kosmos2
+@pytest.fixture
+def kosmos2():
+    return Kosmos2.initialize()
+
+
+# Test Kosmos2 initialization
+def test_kosmos2_initialization(kosmos2):
+    assert kosmos2 is not None
+
+
+# Test Kosmos2 with a sample image
+def test_kosmos2_with_sample_image(kosmos2, sample_image):
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Mocked extract_entities function for testing
+def mock_extract_entities(text):
+    return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))]
+
+
+# Mocked process_entities_to_detections function for testing
+def mock_process_entities_to_detections(entities, image):
+    return Detections(
+        xyxy=[(10, 20, 30, 40), (50, 60, 70, 80)],
+        class_id=[0, 0],
+        confidence=[1.0, 1.0],
+    )
+
+
+# Test Kosmos2 with mocked entity extraction and detection
+def test_kosmos2_with_mocked_extraction_and_detection(
+    kosmos2, sample_image, monkeypatch
+):
+    monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities)
+    monkeypatch.setattr(
+        kosmos2, "process_entities_to_detections", mock_process_entities_to_detections
+    )
+
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 2
+    )
+
+
+# Test Kosmos2 with empty entity extraction
+def test_kosmos2_with_empty_extraction(kosmos2, sample_image, monkeypatch):
+    monkeypatch.setattr(kosmos2, "extract_entities", lambda x: [])
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Test Kosmos2 with invalid image path
+def test_kosmos2_with_invalid_image_path(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2(img="invalid_image_path.jpg")
+
+
+# Additional tests can be added for various scenarios and edge cases
+
+
+# Test Kosmos2 with a larger image
+def test_kosmos2_with_large_image(kosmos2):
+    large_image = Image.new("RGB", (1024, 768))
+    detections = kosmos2(img=large_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Test Kosmos2 with different image formats
+def test_kosmos2_with_different_image_formats(kosmos2, tmp_path):
+    # Create a temporary directory
+    temp_dir = tmp_path / "images"
+    temp_dir.mkdir()
+
+    # Create sample images in different formats
+    image_formats = ["jpeg", "png", "gif", "bmp"]
+    for format in image_formats:
+        image_path = temp_dir / f"sample_image.{format}"
+        Image.new("RGB", (224, 224)).save(image_path)
+
+    # Test Kosmos2 with each image format
+    for format in image_formats:
+        image_path = temp_dir / f"sample_image.{format}"
+        detections = kosmos2(img=image_path)
+        assert isinstance(detections, Detections)
+        assert (
+            len(detections.xyxy)
+            == len(detections.class_id)
+            == len(detections.confidence)
+            == 0
+        )
+
+
+# Test Kosmos2 with a non-existent model
+def test_kosmos2_with_non_existent_model(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2.model = None
+        kosmos2(img="sample_image.jpg")
+
+
+# Test Kosmos2 with a non-existent processor
+def test_kosmos2_with_non_existent_processor(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2.processor = None
+        kosmos2(img="sample_image.jpg")
+
+
+# Test Kosmos2 with missing image
+def test_kosmos2_with_missing_image(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2(img="non_existent_image.jpg")
+
+
+# ... (previous tests)
+
+
+# Test Kosmos2 with a non-existent model and processor
+def test_kosmos2_with_non_existent_model_and_processor(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2.model = None
+        kosmos2.processor = None
+        kosmos2(img="sample_image.jpg")
+
+
+# Test Kosmos2 with a corrupted image
+def test_kosmos2_with_corrupted_image(kosmos2, tmp_path):
+    # Create a temporary directory
+    temp_dir = tmp_path / "images"
+    temp_dir.mkdir()
+
+    # Create a corrupted image
+    corrupted_image_path = temp_dir / "corrupted_image.jpg"
+    with open(corrupted_image_path, "wb") as f:
+        f.write(b"corrupted data")
+
+    with pytest.raises(Exception):
+        kosmos2(img=corrupted_image_path)
+
+
+# Test Kosmos2 with a large batch size
+def test_kosmos2_with_large_batch_size(kosmos2, sample_image):
+    kosmos2.batch_size = 32
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Test Kosmos2 with an invalid compute type
+def test_kosmos2_with_invalid_compute_type(kosmos2, sample_image):
+    kosmos2.compute_type = "invalid_compute_type"
+    with pytest.raises(Exception):
+        kosmos2(img=sample_image)
+
+
+# Test Kosmos2 with a valid HF API key
+def test_kosmos2_with_valid_hf_api_key(kosmos2, sample_image):
+    kosmos2.hf_api_key = "valid_api_key"
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 2
+    )
+
+
+# Test Kosmos2 with an invalid HF API key
+def test_kosmos2_with_invalid_hf_api_key(kosmos2, sample_image):
+    kosmos2.hf_api_key = "invalid_api_key"
+    with pytest.raises(Exception):
+        kosmos2(img=sample_image)
+
+
+# Test Kosmos2 with a very long generated text
+def test_kosmos2_with_long_generated_text(kosmos2, sample_image, monkeypatch):
+    def mock_generate_text(*args, **kwargs):
+        return "A" * 10000
+
+    monkeypatch.setattr(kosmos2.model, "generate", mock_generate_text)
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Test Kosmos2 with entities containing special characters
+def test_kosmos2_with_entities_containing_special_characters(
+    kosmos2, sample_image, monkeypatch
+):
+    def mock_extract_entities(text):
+        return [("entity1 with special characters (ΓΌ, ΓΆ, etc.)", (0.1, 0.2, 0.3, 0.4))]
+
+    monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities)
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 1
+    )
+
+
+# Test Kosmos2 with image containing multiple objects
+def test_kosmos2_with_image_containing_multiple_objects(
+    kosmos2, sample_image, monkeypatch
+):
+    def mock_extract_entities(text):
+        return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))]
+
+    monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities)
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 2
+    )
+
+
+# Test Kosmos2 with image containing no objects
+def test_kosmos2_with_image_containing_no_objects(kosmos2, sample_image, monkeypatch):
+    def mock_extract_entities(text):
+        return []
+
+    monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities)
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 0
+    )
+
+
+# Test Kosmos2 with a valid YouTube video URL
+def test_kosmos2_with_valid_youtube_video_url(kosmos2):
+    youtube_video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
+    detections = kosmos2(video_url=youtube_video_url)
+    assert isinstance(detections, Detections)
+    assert (
+        len(detections.xyxy)
+        == len(detections.class_id)
+        == len(detections.confidence)
+        == 2
+    )
+
+
+# Test Kosmos2 with an invalid YouTube video URL
+def test_kosmos2_with_invalid_youtube_video_url(kosmos2):
+    invalid_youtube_video_url = "https://www.youtube.com/invalid_video"
+    with pytest.raises(Exception):
+        kosmos2(video_url=invalid_youtube_video_url)
+
+
+# Test Kosmos2 with no YouTube video URL provided
+def test_kosmos2_with_no_youtube_video_url(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2(video_url=None)
+
+
+# Test Kosmos2 installation
+def test_kosmos2_installation():
+    kosmos2 = Kosmos2()
+    kosmos2.install()
+    assert os.path.exists("video.mp4")
+    assert os.path.exists("video.mp3")
+    os.remove("video.mp4")
+    os.remove("video.mp3")
+
+
+# Test Kosmos2 termination
+def test_kosmos2_termination(kosmos2):
+    kosmos2.terminate()
+    assert kosmos2.process is None
+
+
+# Test Kosmos2 start_process method
+def test_kosmos2_start_process(kosmos2):
+    kosmos2.start_process()
+    assert kosmos2.process is not None
+
+
+# Test Kosmos2 preprocess_code method
+def test_kosmos2_preprocess_code(kosmos2):
+    code = "print('Hello, World!')"
+    preprocessed_code = kosmos2.preprocess_code(code)
+    assert isinstance(preprocessed_code, str)
+    assert "end_of_execution" in preprocessed_code
+
+
+# Test Kosmos2 run method with debug mode
+def test_kosmos2_run_with_debug_mode(kosmos2, sample_image):
+    kosmos2.debug_mode = True
+    detections = kosmos2(img=sample_image)
+    assert isinstance(detections, Detections)
+
+
+# Test Kosmos2 handle_stream_output method
+def test_kosmos2_handle_stream_output(kosmos2):
+    stream_output = "Sample output"
+    kosmos2.handle_stream_output(stream_output, is_error=False)
+
+
+# Test Kosmos2 run method with invalid image path
+def test_kosmos2_run_with_invalid_image_path(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2.run(img="invalid_image_path.jpg")
+
+
+# Test Kosmos2 run method with invalid video URL
+def test_kosmos2_run_with_invalid_video_url(kosmos2):
+    with pytest.raises(Exception):
+        kosmos2.run(video_url="invalid_video_url")
+
+
+# ... (more tests)
diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py
new file mode 100644
index 00000000..17a28857
--- /dev/null
+++ b/tests/models/whisperx.py
@@ -0,0 +1,206 @@
+import os
+import subprocess
+import tempfile
+from unittest.mock import Mock, patch
+
+import pytest
+import whisperx
+from pydub import AudioSegment
+from pytube import YouTube
+from your_module import SpeechToText
+
+from swarms.models.whisperx import SpeechToText
+
+
+# Fixture to create a temporary directory for testing
+@pytest.fixture
+def temp_dir():
+    with tempfile.TemporaryDirectory() as tempdir:
+        yield tempdir
+
+
+# Mock subprocess.run to prevent actual installation during tests
+@patch.object(subprocess, "run")
+def test_speech_to_text_install(mock_run):
+    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
+    stt.install()
+    mock_run.assert_called_with(["pip", "install", "whisperx"])
+
+
+# Mock pytube.YouTube and pytube.Streams for download tests
+@patch("pytube.YouTube")
+@patch.object(YouTube, "streams")
+def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir):
+    # Mock YouTube and streams
+    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
+    mock_stream = mock_streams().filter().first()
+    mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4")
+    mock_youtube.return_value = mock_youtube
+    mock_youtube.streams = mock_streams
+
+    stt = SpeechToText(video_url)
+    audio_file = stt.download_youtube_video()
+
+    assert os.path.exists(audio_file)
+    assert audio_file.endswith(".mp3")
+
+
+# Mock whisperx.load_model and whisperx.load_audio for transcribe tests
+@patch("whisperx.load_model")
+@patch("whisperx.load_audio")
+@patch("whisperx.load_align_model")
+@patch("whisperx.align")
+@patch.object(whisperx.DiarizationPipeline, "__call__")
+def test_speech_to_text_transcribe_youtube_video(
+    mock_diarization,
+    mock_align,
+    mock_align_model,
+    mock_load_audio,
+    mock_load_model,
+    temp_dir,
+):
+    # Mock whisperx functions
+    mock_load_model.return_value = mock_load_model
+    mock_load_model.transcribe.return_value = {
+        "language": "en",
+        "segments": [{"text": "Hello, World!"}],
+    }
+
+    mock_load_audio.return_value = "audio_path"
+    mock_align_model.return_value = (mock_align_model, "metadata")
+    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}
+
+    # Mock diarization pipeline
+    mock_diarization.return_value = None
+
+    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
+    stt = SpeechToText(video_url)
+    transcription = stt.transcribe_youtube_video()
+
+    assert transcription == "Hello, World!"
+
+
+# More tests for different scenarios and edge cases can be added here.
+
+
+# Test transcribe method with provided audio file
+def test_speech_to_text_transcribe_audio_file(temp_dir):
+    # Create a temporary audio file
+    audio_file = os.path.join(temp_dir, "test_audio.mp3")
+    AudioSegment.silent(duration=500).export(audio_file, format="mp3")
+
+    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
+    transcription = stt.transcribe(audio_file)
+
+    assert transcription == ""
+
+
+# Test transcribe method when Whisperx fails
+@patch("whisperx.load_model")
+@patch("whisperx.load_audio")
+def test_speech_to_text_transcribe_whisperx_failure(
+    mock_load_audio, mock_load_model, temp_dir
+):
+    # Mock whisperx functions to raise an exception
+    mock_load_model.side_effect = Exception("Whisperx failed")
+    mock_load_audio.return_value = "audio_path"
+
+    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
+    transcription = stt.transcribe("audio_path")
+
+    assert transcription == "Whisperx failed"
+
+
+# Test transcribe method with missing 'segments' key in Whisperx output
+@patch("whisperx.load_model")
+@patch("whisperx.load_audio")
+@patch("whisperx.load_align_model")
+@patch("whisperx.align")
+@patch.object(whisperx.DiarizationPipeline, "__call__")
+def test_speech_to_text_transcribe_missing_segments(
+    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
+):
+    # Mock whisperx functions to return incomplete output
+    mock_load_model.return_value = mock_load_model
+    mock_load_model.transcribe.return_value = {"language": "en"}
+
+    mock_load_audio.return_value = "audio_path"
+    mock_align_model.return_value = (mock_align_model, "metadata")
+    mock_align.return_value = {}
+
+    # Mock diarization pipeline
+    mock_diarization.return_value = None
+
+    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
+    transcription = stt.transcribe("audio_path")
+
+    assert transcription == ""
+
+
+# Test transcribe method with Whisperx align failure
+@patch("whisperx.load_model")
+@patch("whisperx.load_audio")
+@patch("whisperx.load_align_model")
+@patch("whisperx.align")
+@patch.object(whisperx.DiarizationPipeline, "__call__")
+def test_speech_to_text_transcribe_align_failure(
+    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
+):
+    # Mock whisperx functions to raise an exception during align
+    mock_load_model.return_value = mock_load_model
+    mock_load_model.transcribe.return_value = {
+        "language": "en",
+        "segments": [{"text": "Hello, World!"}],
+    }
+
+    mock_load_audio.return_value = "audio_path"
+    mock_align_model.return_value = (mock_align_model, "metadata")
+    mock_align.side_effect = Exception("Align failed")
+
+    # Mock diarization pipeline
+    mock_diarization.return_value = None
+
+    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
+    transcription = stt.transcribe("audio_path")
+
+    assert transcription == "Align failed"
+
+
+# Test transcribe_youtube_video when Whisperx diarization fails
+@patch("pytube.YouTube")
+@patch.object(YouTube, "streams")
+@patch("whisperx.DiarizationPipeline")
+@patch("whisperx.load_audio")
+@patch("whisperx.load_align_model")
+@patch("whisperx.align")
+def test_speech_to_text_transcribe_diarization_failure(
+    mock_align,
+    mock_align_model,
+    mock_load_audio,
+    mock_diarization,
+    mock_streams,
+    mock_youtube,
+    temp_dir,
+):
+    # Mock YouTube and streams
+    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
+    mock_stream = mock_streams().filter().first()
+    mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4")
+    mock_youtube.return_value = mock_youtube
+    mock_youtube.streams = mock_streams
+
+    # Mock whisperx functions
+    mock_load_audio.return_value = "audio_path"
+    mock_align_model.return_value = (mock_align_model, "metadata")
+    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}
+
+    # Mock diarization pipeline to raise an exception
+    mock_diarization.side_effect = Exception("Diarization failed")
+
+    stt = SpeechToText(video_url)
+    transcription = stt.transcribe_youtube_video()
+
+    assert transcription == "Diarization failed"
+
+
+# Add more tests for other scenarios and edge cases as needed.
diff --git a/tests/tools/base.py b/tests/tools/base.py
new file mode 100644
index 00000000..4f7e2b4b
--- /dev/null
+++ b/tests/tools/base.py
@@ -0,0 +1,802 @@
+from unittest.mock import MagicMock
+
+import pytest
+from pydantic import BaseModel
+
+from swarms.tools.tool import BaseTool, Runnable, StructuredTool, Tool, tool
+
+# Define test data
+test_input = {"key1": "value1", "key2": "value2"}
+expected_output = "expected_output_value"
+
+# Test with global variables
+global_var = "global"
+
+
+# Basic tests for BaseTool
+def test_base_tool_init():
+    # Test BaseTool initialization
+    tool = BaseTool()
+    assert isinstance(tool, BaseTool)
+
+
+def test_base_tool_invoke():
+    # Test BaseTool invoke method
+    tool = BaseTool()
+    result = tool.invoke(test_input)
+    assert result == expected_output
+
+
+# Basic tests for Tool
+def test_tool_init():
+    # Test Tool initialization
+    tool = Tool()
+    assert isinstance(tool, Tool)
+
+
+def test_tool_invoke():
+    # Test Tool invoke method
+    tool = Tool()
+    result = tool.invoke(test_input)
+    assert result == expected_output
+
+
+# Basic tests for StructuredTool
+def test_structured_tool_init():
+    # Test StructuredTool initialization
+    tool = StructuredTool()
+    assert isinstance(tool, StructuredTool)
+
+
+def test_structured_tool_invoke():
+    # Test StructuredTool invoke method
+    tool = StructuredTool()
+    result = tool.invoke(test_input)
+    assert result == expected_output
+
+
+# Test additional functionality and edge cases as needed
+
+
+def test_tool_creation():
+    tool = Tool(name="test_tool", func=lambda x: x, description="Test tool")
+    assert tool.name == "test_tool"
+    assert tool.func is not None
+    assert tool.description == "Test tool"
+
+
+def test_tool_ainvoke():
+    tool = Tool(name="test_tool", func=lambda x: x, description="Test tool")
+    result = tool.ainvoke("input_data")
+    assert result == "input_data"
+
+
+def test_tool_ainvoke_with_coroutine():
+    async def async_function(input_data):
+        return input_data
+
+    tool = Tool(name="test_tool", coroutine=async_function, description="Test tool")
+    result = tool.ainvoke("input_data")
+    assert result == "input_data"
+
+
+def test_tool_args():
+    def sample_function(input_data):
+        return input_data
+
+    tool = Tool(name="test_tool", func=sample_function, description="Test tool")
+    assert tool.args == {"tool_input": {"type": "string"}}
+
+
+# Basic tests for StructuredTool class
+
+
+def test_structured_tool_creation():
+    class SampleArgsSchema:
+        pass
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=lambda x: x,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    assert tool.name == "test_tool"
+    assert tool.func is not None
+    assert tool.description == "Test tool"
+    assert tool.args_schema == SampleArgsSchema
+
+
+def test_structured_tool_ainvoke():
+    class SampleArgsSchema:
+        pass
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=lambda x: x,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    result = tool.ainvoke({"tool_input": "input_data"})
+    assert result == "input_data"
+
+
+def test_structured_tool_ainvoke_with_coroutine():
+    class SampleArgsSchema:
+        pass
+
+    async def async_function(input_data):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        coroutine=async_function,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    result = tool.ainvoke({"tool_input": "input_data"})
+    assert result == "input_data"
+
+
+def test_structured_tool_args():
+    class SampleArgsSchema:
+        pass
+
+    def sample_function(input_data):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    assert tool.args == {"tool_input": {"type": "string"}}
+
+
+# Additional tests for exception handling
+
+
+def test_tool_ainvoke_exception():
+    tool = Tool(name="test_tool", func=None, description="Test tool")
+    with pytest.raises(NotImplementedError):
+        tool.ainvoke("input_data")
+
+
+def test_tool_ainvoke_with_coroutine_exception():
+    tool = Tool(name="test_tool", coroutine=None, description="Test tool")
+    with pytest.raises(NotImplementedError):
+        tool.ainvoke("input_data")
+
+
+def test_structured_tool_ainvoke_exception():
+    class SampleArgsSchema:
+        pass
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=None,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    with pytest.raises(NotImplementedError):
+        tool.ainvoke({"tool_input": "input_data"})
+
+
+def test_structured_tool_ainvoke_with_coroutine_exception():
+    class SampleArgsSchema:
+        pass
+
+    tool = StructuredTool(
+        name="test_tool",
+        coroutine=None,
+        description="Test tool",
+        args_schema=SampleArgsSchema,
+    )
+    with pytest.raises(NotImplementedError):
+        tool.ainvoke({"tool_input": "input_data"})
+
+
+def test_tool_description_not_provided():
+    tool = Tool(name="test_tool", func=lambda x: x)
+    assert tool.name == "test_tool"
+    assert tool.func is not None
+    assert tool.description == ""
+
+
+def test_tool_invoke_with_callbacks():
+    def sample_function(input_data, callbacks=None):
+        if callbacks:
+            callbacks.on_start()
+            callbacks.on_finish()
+        return input_data
+
+    tool = Tool(name="test_tool", func=sample_function)
+    callbacks = MagicMock()
+    result = tool.invoke("input_data", callbacks=callbacks)
+    assert result == "input_data"
+    callbacks.on_start.assert_called_once()
+    callbacks.on_finish.assert_called_once()
+
+
+def test_tool_invoke_with_new_argument():
+    def sample_function(input_data, callbacks=None):
+        return input_data
+
+    tool = Tool(name="test_tool", func=sample_function)
+    result = tool.invoke("input_data", callbacks=None)
+    assert result == "input_data"
+
+
+def test_tool_ainvoke_with_new_argument():
+    async def async_function(input_data, callbacks=None):
+        return input_data
+
+    tool = Tool(name="test_tool", coroutine=async_function)
+    result = tool.ainvoke("input_data", callbacks=None)
+    assert result == "input_data"
+
+
+def test_tool_description_from_docstring():
+    def sample_function(input_data):
+        """Sample function docstring"""
+        return input_data
+
+    tool = Tool(name="test_tool", func=sample_function)
+    assert tool.description == "Sample function docstring"
+
+
+def test_tool_ainvoke_with_exceptions():
+    async def async_function(input_data):
+        raise ValueError("Test exception")
+
+    tool = Tool(name="test_tool", coroutine=async_function)
+    with pytest.raises(ValueError):
+        tool.ainvoke("input_data")
+
+
+# Additional tests for StructuredTool class
+
+
+def test_structured_tool_infer_schema_false():
+    def sample_function(input_data):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        args_schema=None,
+        infer_schema=False,
+    )
+    assert tool.args_schema is None
+
+
+def test_structured_tool_ainvoke_with_callbacks():
+    class SampleArgsSchema:
+        pass
+
+    def sample_function(input_data, callbacks=None):
+        if callbacks:
+            callbacks.on_start()
+            callbacks.on_finish()
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        args_schema=SampleArgsSchema,
+    )
+    callbacks = MagicMock()
+    result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks)
+    assert result == "input_data"
+    callbacks.on_start.assert_called_once()
+    callbacks.on_finish.assert_called_once()
+
+
+def test_structured_tool_description_not_provided():
+    class SampleArgsSchema:
+        pass
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=lambda x: x,
+        args_schema=SampleArgsSchema,
+    )
+    assert tool.name == "test_tool"
+    assert tool.func is not None
+    assert tool.description == ""
+
+
+def test_structured_tool_args_schema():
+    class SampleArgsSchema:
+        pass
+
+    def sample_function(input_data):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        args_schema=SampleArgsSchema,
+    )
+    assert tool.args_schema == SampleArgsSchema
+
+
+def test_structured_tool_args_schema_inference():
+    def sample_function(input_data):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        args_schema=None,
+        infer_schema=True,
+    )
+    assert tool.args_schema is not None
+
+
+def test_structured_tool_ainvoke_with_new_argument():
+    class SampleArgsSchema:
+        pass
+
+    def sample_function(input_data, callbacks=None):
+        return input_data
+
+    tool = StructuredTool(
+        name="test_tool",
+        func=sample_function,
+        args_schema=SampleArgsSchema,
+    )
+    result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None)
+    assert result == "input_data"
+
+
+def test_structured_tool_ainvoke_with_exceptions():
+    class SampleArgsSchema:
+        pass
+
+    async def async_function(input_data):
+        raise ValueError("Test exception")
+
+    tool = StructuredTool(
+        name="test_tool",
+        coroutine=async_function,
+        args_schema=SampleArgsSchema,
+    )
+    with pytest.raises(ValueError):
+        tool.ainvoke({"tool_input": "input_data"})
+
+
+# Test additional functionality and edge cases
+def test_tool_with_fixture(some_fixture):
+    # Test Tool with a fixture
+    tool = Tool()
+    result = tool.invoke(test_input)
+    assert result == expected_output
+
+
+def test_structured_tool_with_fixture(some_fixture):
+    # Test StructuredTool with a fixture
+    tool = StructuredTool()
+    result = tool.invoke(test_input)
+    assert result == expected_output
+
+
+def test_base_tool_verbose_logging(caplog):
+    # Test verbose logging in BaseTool
+    tool = BaseTool(verbose=True)
+    result = tool.invoke(test_input)
+    assert result == expected_output
+    assert "Verbose logging" in caplog.text
+
+
+def test_tool_exception_handling():
+    # Test exception handling in Tool
+    tool = Tool()
+    with pytest.raises(Exception):
+        tool.invoke(test_input, raise_exception=True)
+
+
+def test_structured_tool_async_invoke():
+    # Test asynchronous invoke in StructuredTool
+    tool = StructuredTool()
+    result = tool.ainvoke(test_input)
+    assert result == expected_output
+
+
+def test_tool_async_invoke_with_fixture(some_fixture):
+    # Test asynchronous invoke with a fixture in Tool
+    tool = Tool()
+    result = tool.ainvoke(test_input)
+    assert result == expected_output
+
+
+# Add more tests for specific functionalities and edge cases as needed
+# Import necessary libraries and modules
+
+
+# Example of a mock function to be used in testing
+def mock_function(arg: str) -> str:
+    """A simple mock function for testing."""
+    return f"Processed {arg}"
+
+
+# Example of a Runnable class for testing
+class MockRunnable(Runnable):
+    # Define necessary methods and properties
+    pass
+
+
+# Fixture for creating a mock function
+@pytest.fixture
+def mock_func():
+    return mock_function
+
+
+# Fixture for creating a Runnable instance
+@pytest.fixture
+def mock_runnable():
+    return MockRunnable()
+
+
+# Basic functionality tests
+def test_tool_with_callable(mock_func):
+    # Test creating a tool with a simple callable
+    tool_instance = tool(mock_func)
+    assert isinstance(tool_instance, BaseTool)
+
+
+def test_tool_with_runnable(mock_runnable):
+    # Test creating a tool with a Runnable instance
+    tool_instance = tool(mock_runnable)
+    assert isinstance(tool_instance, BaseTool)
+
+
+# ... more basic functionality tests ...
+
+
+# Argument handling tests
+def test_tool_with_invalid_argument():
+    # Test passing an invalid argument type
+    with pytest.raises(ValueError):
+        tool(123)  # Using an integer instead of a string/callable/Runnable
+
+
+def test_tool_with_multiple_arguments(mock_func):
+    # Test passing multiple valid arguments
+    tool_instance = tool("mock", mock_func)
+    assert isinstance(tool_instance, BaseTool)
+
+
+# ... more argument handling tests ...
+
+
+# Schema inference and application tests
+class TestSchema(BaseModel):
+    arg: str
+
+
+def test_tool_with_args_schema(mock_func):
+    # Test passing a custom args_schema
+    tool_instance = tool(mock_func, args_schema=TestSchema)
+    assert tool_instance.args_schema == TestSchema
+
+
+# ... more schema tests ...
+
+
+# Exception handling tests
+def test_tool_function_without_docstring():
+    # Test that a ValueError is raised if the function lacks a docstring
+    def no_doc_func(arg: str) -> str:
+        return arg
+
+    with pytest.raises(ValueError):
+        tool(no_doc_func)
+
+
+# ... more exception tests ...
+
+
+# Decorator behavior tests
+@pytest.mark.asyncio
+async def test_async_tool_function():
+    # Test an async function with the tool decorator
+    @tool
+    async def async_func(arg: str) -> str:
+        return arg
+
+    # Add async specific assertions here
+
+
+# ... more decorator tests ...
+
+
+class MockSchema(BaseModel):
+    """Mock schema for testing args_schema."""
+
+    arg: str
+
+
+# Test suite starts here
+class TestTool:
+    # Basic Functionality Tests
+    def test_tool_with_valid_callable_creates_base_tool(self, mock_func):
+        result = tool(mock_func)
+        assert isinstance(result, BaseTool)
+
+    def test_tool_returns_correct_function_name(self, mock_func):
+        result = tool(mock_func)
+        assert result.func.__name__ == "mock_function"
+
+    # Argument Handling Tests
+    def test_tool_with_string_and_runnable(self, mock_runnable):
+        result = tool("mock_runnable", mock_runnable)
+        assert isinstance(result, BaseTool)
+
+    def test_tool_raises_error_with_invalid_arguments(self):
+        with pytest.raises(ValueError):
+            tool(123)
+
+    # Schema Inference and Application Tests
+    def test_tool_with_args_schema(self, mock_func):
+        result = tool(mock_func, args_schema=MockSchema)
+        assert result.args_schema == MockSchema
+
+    def test_tool_with_infer_schema_true(self, mock_func):
+        tool(mock_func, infer_schema=True)
+        # Assertions related to schema inference
+
+    # Return Direct Feature Tests
+    def test_tool_with_return_direct_true(self, mock_func):
+        tool(mock_func, return_direct=True)
+        # Assertions for return_direct behavior
+
+    # Error Handling Tests
+    def test_tool_raises_error_without_docstring(self):
+        def no_doc_func(arg: str) -> str:
+            return arg
+
+        with pytest.raises(ValueError):
+            tool(no_doc_func)
+
+    def test_tool_raises_error_runnable_without_object_schema(self, mock_runnable):
+        with pytest.raises(ValueError):
+            tool(mock_runnable)
+
+    # Decorator Behavior Tests
+    @pytest.mark.asyncio
+    async def test_async_tool_function(self):
+        @tool
+        async def async_func(arg: str) -> str:
+            return arg
+
+        # Assertions for async behavior
+
+    # Integration with StructuredTool and Tool Classes
+    def test_integration_with_structured_tool(self, mock_func):
+        result = tool(mock_func)
+        assert isinstance(result, StructuredTool)
+
+    # Concurrency and Async Handling Tests
+    def test_concurrency_in_tool(self, mock_func):
+        # Test related to concurrency
+        pass
+
+    # Mocking and Isolation Tests
+    def test_mocking_external_dependencies(self, mocker):
+        # Use mocker to mock external dependencies
+        pass
+
+    def test_tool_with_different_return_types(self):
+        @tool
+        def return_int(arg: str) -> int:
+            return int(arg)
+
+        result = return_int("123")
+        assert isinstance(result, int)
+        assert result == 123
+
+        @tool
+        def return_bool(arg: str) -> bool:
+            return arg.lower() in ["true", "yes"]
+
+        result = return_bool("true")
+        assert isinstance(result, bool)
+        assert result is True
+
+    # Test with multiple arguments
+    def test_tool_with_multiple_args(self):
+        @tool
+        def concat_strings(a: str, b: str) -> str:
+            return a + b
+
+        result = concat_strings("Hello", "World")
+        assert result == "HelloWorld"
+
+    # Test handling of optional arguments
+    def test_tool_with_optional_args(self):
+        @tool
+        def greet(name: str, greeting: str = "Hello") -> str:
+            return f"{greeting} {name}"
+
+        assert greet("Alice") == "Hello Alice"
+        assert greet("Alice", greeting="Hi") == "Hi Alice"
+
+    # Test with variadic arguments
+    def test_tool_with_variadic_args(self):
+        @tool
+        def sum_numbers(*numbers: int) -> int:
+            return sum(numbers)
+
+        assert sum_numbers(1, 2, 3) == 6
+        assert sum_numbers(10, 20) == 30
+
+    # Test with keyword arguments
+    def test_tool_with_kwargs(self):
+        @tool
+        def build_query(**kwargs) -> str:
+            return "&".join(f"{k}={v}" for k, v in kwargs.items())
+
+        assert build_query(a=1, b=2) == "a=1&b=2"
+        assert build_query(foo="bar") == "foo=bar"
+
+    # Test with mixed types of arguments
+    def test_tool_with_mixed_args(self):
+        @tool
+        def mixed_args(a: int, b: str, *args, **kwargs) -> str:
+            return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}"
+
+        assert mixed_args(1, "b", "c", "d", x="y", z="w") == "1b2y-w"
+
+    # Test error handling with incorrect types
+    def test_tool_error_with_incorrect_types(self):
+        @tool
+        def add_numbers(a: int, b: int) -> int:
+            return a + b
+
+        with pytest.raises(TypeError):
+            add_numbers("1", "2")
+
+    # Test with nested tools
+    def test_nested_tools(self):
+        @tool
+        def inner_tool(arg: str) -> str:
+            return f"Inner {arg}"
+
+        @tool
+        def outer_tool(arg: str) -> str:
+            return f"Outer {inner_tool(arg)}"
+
+        assert outer_tool("Test") == "Outer Inner Test"
+
+    def test_tool_with_global_variable(self):
+        @tool
+        def access_global(arg: str) -> str:
+            return f"{global_var} {arg}"
+
+        assert access_global("Var") == "global Var"
+
+    # Test with environment variables
+    def test_tool_with_env_variables(self, monkeypatch):
+        monkeypatch.setenv("TEST_VAR", "Environment")
+
+        @tool
+        def access_env_variable(arg: str) -> str:
+            import os
+
+            return f"{os.environ['TEST_VAR']} {arg}"
+
+        assert access_env_variable("Var") == "Environment Var"
+
+    # ... [Previous test cases] ...
+
+    # Test with complex data structures
+    def test_tool_with_complex_data_structures(self):
+        @tool
+        def process_data(data: dict) -> list:
+            return [data[key] for key in sorted(data.keys())]
+
+        result = process_data({"b": 2, "a": 1})
+        assert result == [1, 2]
+
+    # Test handling exceptions within the tool function
+    def test_tool_handling_internal_exceptions(self):
+        @tool
+        def function_that_raises(arg: str):
+            if arg == "error":
+                raise ValueError("Error occurred")
+            return arg
+
+        with pytest.raises(ValueError):
+            function_that_raises("error")
+        assert function_that_raises("ok") == "ok"
+
+    # Test with functions returning None
+    def test_tool_with_none_return(self):
+        @tool
+        def return_none(arg: str):
+            return None
+
+        assert return_none("anything") is None
+
+    # Test with lambda functions
+    def test_tool_with_lambda(self):
+        tool_lambda = tool(lambda x: x * 2)
+        assert tool_lambda(3) == 6
+
+    # Test with class methods
+    def test_tool_with_class_method(self):
+        class MyClass:
+            @tool
+            def method(self, arg: str) -> str:
+                return f"Method {arg}"
+
+        obj = MyClass()
+        assert obj.method("test") == "Method test"
+
+    # Test tool function with inheritance
+    def test_tool_with_inheritance(self):
+        class Parent:
+            @tool
+            def parent_method(self, arg: str) -> str:
+                return f"Parent {arg}"
+
+        class Child(Parent):
+            @tool
+            def child_method(self, arg: str) -> str:
+                return f"Child {arg}"
+
+        child_obj = Child()
+        assert child_obj.parent_method("test") == "Parent test"
+        assert child_obj.child_method("test") == "Child test"
+
+    # Test with decorators stacking
+    def test_tool_with_multiple_decorators(self):
+        def another_decorator(func):
+            def wrapper(*args, **kwargs):
+                return f"Decorated {func(*args, **kwargs)}"
+
+            return wrapper
+
+        @tool
+        @another_decorator
+        def decorated_function(arg: str):
+            return f"Function {arg}"
+
+        assert decorated_function("test") == "Decorated Function test"
+
+    # Test tool function when used in a multi-threaded environment
+    def test_tool_in_multithreaded_environment(self):
+        import threading
+
+        @tool
+        def threaded_function(arg: int) -> int:
+            return arg * 2
+
+        results = []
+
+        def thread_target():
+            results.append(threaded_function(5))
+
+        threads = [threading.Thread(target=thread_target) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert results == [10] * 10
+
+    # Test with recursive functions
+    def test_tool_with_recursive_function(self):
+        @tool
+        def recursive_function(n: int) -> int:
+            if n == 0:
+                return 0
+            else:
+                return n + recursive_function(n - 1)
+
+        assert recursive_function(5) == 15
+
+
+# Additional tests can be added here to cover more scenarios
diff --git a/tests/utils/subprocess_code_interpreter.py b/tests/utils/subprocess_code_interpreter.py
new file mode 100644
index 00000000..601f8a09
--- /dev/null
+++ b/tests/utils/subprocess_code_interpreter.py
@@ -0,0 +1,246 @@
+import time
+import threading
+import pytest
+import subprocess
+from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter
+
+
+@pytest.fixture
+def subprocess_code_interpreter():
+    interpreter = SubprocessCodeInterpreter()
+    interpreter.start_cmd = "python -c"
+    yield interpreter
+    interpreter.terminate()
+
+
+def test_base_code_interpreter_init():
+    interpreter = BaseCodeInterpreter()
+    assert isinstance(interpreter, BaseCodeInterpreter)
+
+
+def test_base_code_interpreter_run_not_implemented():
+    interpreter = BaseCodeInterpreter()
+    with pytest.raises(NotImplementedError):
+        interpreter.run("code")
+
+
+def test_base_code_interpreter_terminate_not_implemented():
+    interpreter = BaseCodeInterpreter()
+    with pytest.raises(NotImplementedError):
+        interpreter.terminate()
+
+
+def test_subprocess_code_interpreter_init(subprocess_code_interpreter):
+    assert isinstance(subprocess_code_interpreter, SubprocessCodeInterpreter)
+
+
+def test_subprocess_code_interpreter_start_process(subprocess_code_interpreter):
+    subprocess_code_interpreter.start_process()
+    assert subprocess_code_interpreter.process is not None
+
+
+def test_subprocess_code_interpreter_terminate(subprocess_code_interpreter):
+    subprocess_code_interpreter.start_process()
+    subprocess_code_interpreter.terminate()
+    assert subprocess_code_interpreter.process.poll() is not None
+
+
+def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter):
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("Hello, World!" in output.get("output", "") for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter):
+    code = 'print("Hello, World")\nraise ValueError("Error!")'
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("Error!" in output.get("output", "") for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_keyboard_interrupt(
+    subprocess_code_interpreter,
+):
+    code = 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise KeyboardInterrupt'
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("KeyboardInterrupt" in output.get("output", "") for output in result)
+
+
+def test_subprocess_code_interpreter_run_max_retries(
+    subprocess_code_interpreter, monkeypatch
+):
+    def mock_subprocess_popen(*args, **kwargs):
+        raise subprocess.CalledProcessError(1, "mocked_cmd")
+
+    monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen)
+
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    assert any(
+        "Maximum retries reached. Could not execute code." in output.get("output", "")
+        for output in result
+    )
+
+
+def test_subprocess_code_interpreter_run_retry_on_error(
+    subprocess_code_interpreter, monkeypatch
+):
+    def mock_subprocess_popen(*args, **kwargs):
+        nonlocal popen_count
+        if popen_count == 0:
+            popen_count += 1
+            raise subprocess.CalledProcessError(1, "mocked_cmd")
+        else:
+            return subprocess.Popen(
+                "echo 'Hello, World!'",
+                shell=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+
+    monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen)
+    popen_count = 0
+
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("Hello, World!" in output.get("output", "") for output in result)
+
+
+# Add more tests to cover other aspects of the code and edge cases as needed
+
+# Import statements and fixtures from the previous code block
+
+
+def test_subprocess_code_interpreter_line_postprocessor(subprocess_code_interpreter):
+    line = "This is a test line"
+    processed_line = subprocess_code_interpreter.line_postprocessor(line)
+    assert processed_line == line  # No processing, should remain the same
+
+
+def test_subprocess_code_interpreter_preprocess_code(subprocess_code_interpreter):
+    code = 'print("Hello, World!")'
+    preprocessed_code = subprocess_code_interpreter.preprocess_code(code)
+    assert preprocessed_code == code  # No preprocessing, should remain the same
+
+
+def test_subprocess_code_interpreter_detect_active_line(subprocess_code_interpreter):
+    line = "Active line: 5"
+    active_line = subprocess_code_interpreter.detect_active_line(line)
+    assert active_line == 5
+
+
+def test_subprocess_code_interpreter_detect_end_of_execution(
+    subprocess_code_interpreter,
+):
+    line = "Execution completed."
+    end_of_execution = subprocess_code_interpreter.detect_end_of_execution(line)
+    assert end_of_execution is True
+
+
+def test_subprocess_code_interpreter_run_debug_mode(
+    subprocess_code_interpreter, capsys
+):
+    subprocess_code_interpreter.debug_mode = True
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    captured = capsys.readouterr()
+    assert "Running code:\n" in captured.out
+    assert "Received output line:\n" in captured.out
+
+
+def test_subprocess_code_interpreter_run_no_debug_mode(
+    subprocess_code_interpreter, capsys
+):
+    subprocess_code_interpreter.debug_mode = False
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    captured = capsys.readouterr()
+    assert "Running code:\n" not in captured.out
+    assert "Received output line:\n" not in captured.out
+
+
+def test_subprocess_code_interpreter_run_empty_output_queue(
+    subprocess_code_interpreter,
+):
+    code = 'print("Hello, World!")'
+    result = list(subprocess_code_interpreter.run(code))
+    assert not any("active_line" in output for output in result)
+
+
+def test_subprocess_code_interpreter_handle_stream_output_stdout(
+    subprocess_code_interpreter,
+):
+    line = "This is a test line"
+    subprocess_code_interpreter.handle_stream_output(threading.current_thread(), False)
+    subprocess_code_interpreter.process.stdout.write(line + "\n")
+    subprocess_code_interpreter.process.stdout.flush()
+    time.sleep(0.1)
+    output = subprocess_code_interpreter.output_queue.get()
+    assert output["output"] == line
+
+
+def test_subprocess_code_interpreter_handle_stream_output_stderr(
+    subprocess_code_interpreter,
+):
+    line = "This is an error line"
+    subprocess_code_interpreter.handle_stream_output(threading.current_thread(), True)
+    subprocess_code_interpreter.process.stderr.write(line + "\n")
+    subprocess_code_interpreter.process.stderr.flush()
+    time.sleep(0.1)
+    output = subprocess_code_interpreter.output_queue.get()
+    assert output["output"] == line
+
+
+def test_subprocess_code_interpreter_run_with_preprocess_code(
+    subprocess_code_interpreter, capsys
+):
+    code = 'print("Hello, World!")'
+    subprocess_code_interpreter.preprocess_code = (
+        lambda x: x.upper()
+    )  # Modify code in preprocess_code
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("Hello, World!" in output.get("output", "") for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_exception(
+    subprocess_code_interpreter, capsys
+):
+    code = 'print("Hello, World!")'
+    subprocess_code_interpreter.start_cmd = (
+        "nonexistent_command"  # Force an exception during subprocess creation
+    )
+    result = list(subprocess_code_interpreter.run(code))
+    assert any(
+        "Maximum retries reached" in output.get("output", "") for output in result
+    )
+
+
+def test_subprocess_code_interpreter_run_with_active_line(
+    subprocess_code_interpreter, capsys
+):
+    code = "a = 5\nprint(a)"  # Contains an active line
+    result = list(subprocess_code_interpreter.run(code))
+    assert any(output.get("active_line") == 5 for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_end_of_execution(
+    subprocess_code_interpreter, capsys
+):
+    code = 'print("Hello, World!")'  # Simple code without active line marker
+    result = list(subprocess_code_interpreter.run(code))
+    assert any(output.get("active_line") is None for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_multiple_lines(
+    subprocess_code_interpreter, capsys
+):
+    code = "a = 5\nb = 10\nprint(a + b)"
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("15" in output.get("output", "") for output in result)
+
+
+def test_subprocess_code_interpreter_run_with_unicode_characters(
+    subprocess_code_interpreter, capsys
+):
+    code = 'print("γ“γ‚“γ«γ‘γ―γ€δΈ–η•Œ")'  # Contains unicode characters
+    result = list(subprocess_code_interpreter.run(code))
+    assert any("γ“γ‚“γ«γ‘γ―γ€δΈ–η•Œ" in output.get("output", "") for output in result)