Former-commit-id: 2e6efb4781
grit/923f7c6f-0958-480b-8748-ea6bbf1c2084
parent
3815c73b64
commit
4ffa418178
@ -1 +1,125 @@
|
|||||||
"""An ultra fast speech to text model."""
|
# speech to text tool
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import whisperx
|
||||||
|
from pydub import AudioSegment
|
||||||
|
from pytube import YouTube
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperX:
    """Download a YouTube video's audio track and transcribe it with whisperx.

    Example:
        video_url = "url"
        speech_to_text = WhisperX(video_url)
        transcription = speech_to_text.transcribe_youtube_video()
        print(transcription)

    Args:
        video_url: URL handed to ``pytube.YouTube``.
        audio_format: Format (and file extension) of the exported audio file.
        device: Device string passed to the whisperx loaders.
        batch_size: Batch size passed to ``model.transcribe``.
        compute_type: Compute type passed to ``whisperx.load_model``.
        hf_api_key: Token passed to ``whisperx.DiarizationPipeline`` as
            ``use_auth_token``.
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        """Store the configuration; nothing is downloaded or loaded here."""
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Run ``pip install`` for whisperx, pytube and pydub, one call each."""
        for package in ("whisperx", "pytube", "pydub"):
            subprocess.run(["pip", "install", package])

    def download_youtube_video(self):
        """Download the audio stream and convert it to ``self.audio_format``.

        Returns:
            The name of the exported audio file (``video.<audio_format>``),
            written to the current working directory.

        Raises:
            ValueError: If the video exposes no audio-only stream.
        """
        video_file = "video.mp4"
        audio_file = f"video.{self.audio_format}"

        # Download video 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        if yt_stream is None:
            raise ValueError(
                f"No audio-only stream found for {self.video_url!r}"
            )

        try:
            yt_stream.download(filename=video_file)

            # Convert video to audio 🎧
            video = AudioSegment.from_file(video_file, format="mp4")
            video.export(audio_file, format=self.audio_format)
        finally:
            # Always drop the intermediate download, even if conversion fails.
            if os.path.exists(video_file):
                os.remove(video_file)

        return audio_file

    def transcribe_youtube_video(self):
        """Download ``self.video_url`` and return its transcription.

        Uses the instance's ``device``, ``batch_size`` and ``compute_type``
        (whose defaults are "cuda", 16 and "float16") via ``transcribe``.

        Returns:
            The transcription text, or ``None`` when the aligned result lacks
            the expected keys (see ``transcribe``).
        """
        return self.transcribe(self.download_youtube_video())

    def transcribe(self, audio_file):
        """Transcribe ``audio_file`` and return the segment texts joined by spaces.

        Args:
            audio_file: Path of the audio file to transcribe.

        Returns:
            The joined transcription string, or ``None`` (after printing a
            message) if the aligned result has no ``"segments"`` key or a
            segment has no ``"text"`` key.
        """
        # 1. Transcribe with original Whisper (batched) 🗣️
        # compute_type goes by keyword so it cannot bind to another positional
        # parameter of load_model.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        # NOTE(review): the diarization output is not used for the returned
        # text; confirm whether speaker labels should be merged into `result`.
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            return " ".join(segment["text"] for segment in segments)
        except KeyError:
            print("The key 'segments' is not found in the result.")
            return None
|
||||||
|
@ -1,125 +0,0 @@
|
|||||||
# speech to text tool
|
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
import whisperx
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from pytube import YouTube
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToText:
    """Download a YouTube video's audio track and transcribe it with whisperx.

    Example:
        video_url = "url"
        speech_to_text = SpeechToText(video_url)
        transcription = speech_to_text.transcribe_youtube_video()
        print(transcription)

    Args:
        video_url: URL handed to ``pytube.YouTube``.
        audio_format: Format (and file extension) of the exported audio file.
        device: Device string passed to the whisperx loaders.
        batch_size: Batch size passed to ``model.transcribe``.
        compute_type: Compute type passed to ``whisperx.load_model``.
        hf_api_key: Token passed to ``whisperx.DiarizationPipeline`` as
            ``use_auth_token``.
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        """Store the configuration; nothing is downloaded or loaded here."""
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Run ``pip install`` for whisperx, pytube and pydub, one call each."""
        for package in ("whisperx", "pytube", "pydub"):
            subprocess.run(["pip", "install", package])

    def download_youtube_video(self):
        """Download the audio stream and convert it to ``self.audio_format``.

        Returns:
            The name of the exported audio file (``video.<audio_format>``),
            written to the current working directory.

        Raises:
            ValueError: If the video exposes no audio-only stream.
        """
        video_file = "video.mp4"
        audio_file = f"video.{self.audio_format}"

        # Download video 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        if yt_stream is None:
            raise ValueError(
                f"No audio-only stream found for {self.video_url!r}"
            )

        try:
            yt_stream.download(filename=video_file)

            # Convert video to audio 🎧
            video = AudioSegment.from_file(video_file, format="mp4")
            video.export(audio_file, format=self.audio_format)
        finally:
            # Always drop the intermediate download, even if conversion fails.
            if os.path.exists(video_file):
                os.remove(video_file)

        return audio_file

    def transcribe_youtube_video(self):
        """Download ``self.video_url`` and return its transcription.

        Uses the instance's ``device``, ``batch_size`` and ``compute_type``
        (whose defaults are "cuda", 16 and "float16") via ``transcribe``.

        Returns:
            The transcription text, or ``None`` when the aligned result lacks
            the expected keys (see ``transcribe``).
        """
        return self.transcribe(self.download_youtube_video())

    def transcribe(self, audio_file):
        """Transcribe ``audio_file`` and return the segment texts joined by spaces.

        Args:
            audio_file: Path of the audio file to transcribe.

        Returns:
            The joined transcription string, or ``None`` (after printing a
            message) if the aligned result has no ``"segments"`` key or a
            segment has no ``"text"`` key.
        """
        # 1. Transcribe with original Whisper (batched) 🗣️
        # compute_type goes by keyword so it cannot bind to another positional
        # parameter of load_model.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        # NOTE(review): the diarization output is not used for the returned
        # text; confirm whether speaker labels should be merged into `result`.
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            return " ".join(segment["text"] for segment in segments)
        except KeyError:
            print("The key 'segments' is not found in the result.")
            return None
|
|
@ -0,0 +1,365 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from swarms.models.kosmos2 import Kosmos2, Detections
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture for a sample image
|
||||||
|
@pytest.fixture
def sample_image():
    """Provide a newly created 224x224 RGB PIL image."""
    return Image.new("RGB", (224, 224))
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture for initializing Kosmos2
|
||||||
|
@pytest.fixture
def kosmos2():
    """Provide the object returned by ``Kosmos2.initialize()``."""
    instance = Kosmos2.initialize()
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 initialization
|
||||||
|
def test_kosmos2_initialization(kosmos2):
    """The ``kosmos2`` fixture must yield something other than ``None``."""
    was_created = kosmos2 is not None
    assert was_created
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a sample image
|
||||||
|
def test_kosmos2_with_sample_image(kosmos2, sample_image):
    """Calling the model on the sample image is expected to give empty ``Detections``."""
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    # All three parallel fields must be empty.
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Mocked extract_entities function for testing
|
||||||
|
def mock_extract_entities(text):
    """Stand-in extractor: ignores ``text`` and returns two fixed (name, box) pairs."""
    first = ("entity1", (0.1, 0.2, 0.3, 0.4))
    second = ("entity2", (0.5, 0.6, 0.7, 0.8))
    return [first, second]
|
||||||
|
|
||||||
|
|
||||||
|
# Mocked process_entities_to_detections function for testing
|
||||||
|
def mock_process_entities_to_detections(entities, image):
    """Stand-in converter: ignores its arguments and returns two fixed detections."""
    boxes = [(10, 20, 30, 40), (50, 60, 70, 80)]
    class_ids = [0, 0]
    scores = [1.0, 1.0]
    return Detections(xyxy=boxes, class_id=class_ids, confidence=scores)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with mocked entity extraction and detection
|
||||||
|
def test_kosmos2_with_mocked_extraction_and_detection(
    kosmos2, sample_image, monkeypatch
):
    """With both helpers replaced by the module-level mocks, two detections are expected."""
    replacements = (
        ("extract_entities", mock_extract_entities),
        ("process_entities_to_detections", mock_process_entities_to_detections),
    )
    for attribute, stand_in in replacements:
        monkeypatch.setattr(kosmos2, attribute, stand_in)

    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with empty entity extraction
|
||||||
|
def test_kosmos2_with_empty_extraction(kosmos2, sample_image, monkeypatch):
    """When entity extraction yields nothing, the detections are expected to be empty."""

    def no_entities(x):
        return []

    monkeypatch.setattr(kosmos2, "extract_entities", no_entities)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with invalid image path
|
||||||
|
def test_kosmos2_with_invalid_image_path(kosmos2):
    """Calling the model with a path string that is not a real image must raise."""
    bad_path = "invalid_image_path.jpg"
    with pytest.raises(Exception):
        kosmos2(img=bad_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Additional tests can be added for various scenarios and edge cases
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a larger image
|
||||||
|
def test_kosmos2_with_large_image(kosmos2):
    """A 1024x768 RGB image is expected to give empty ``Detections``."""
    width, height = 1024, 768
    found = kosmos2(img=Image.new("RGB", (width, height)))
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with different image formats
|
||||||
|
def test_kosmos2_with_different_image_formats(kosmos2, tmp_path):
    """Images saved as jpeg, png, gif and bmp are each expected to give empty ``Detections``."""
    # Create a temporary directory
    temp_dir = tmp_path / "images"
    temp_dir.mkdir()

    # Build each path once; `extension` avoids shadowing the builtin `format`.
    image_paths = [
        temp_dir / f"sample_image.{extension}"
        for extension in ("jpeg", "png", "gif", "bmp")
    ]

    # Create sample images in different formats
    for image_path in image_paths:
        Image.new("RGB", (224, 224)).save(image_path)

    # Test Kosmos2 with each image format
    for image_path in image_paths:
        detections = kosmos2(img=image_path)
        assert isinstance(detections, Detections)
        assert (
            len(detections.xyxy)
            == len(detections.class_id)
            == len(detections.confidence)
            == 0
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a non-existent model
|
||||||
|
def test_kosmos2_with_non_existent_model(kosmos2):
    """Calling the model after ``model`` has been set to ``None`` must raise."""
    # Setup stays outside pytest.raises so only the call itself can satisfy it.
    kosmos2.model = None
    with pytest.raises(Exception):
        kosmos2(img="sample_image.jpg")
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a non-existent processor
|
||||||
|
def test_kosmos2_with_non_existent_processor(kosmos2):
    """Calling the model after ``processor`` has been set to ``None`` must raise."""
    # Setup stays outside pytest.raises so only the call itself can satisfy it.
    kosmos2.processor = None
    with pytest.raises(Exception):
        kosmos2(img="sample_image.jpg")
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with missing image
|
||||||
|
def test_kosmos2_with_missing_image(kosmos2):
    """Calling the model with the name of a file that is never created must raise."""
    missing = "non_existent_image.jpg"
    with pytest.raises(Exception):
        kosmos2(img=missing)
|
||||||
|
|
||||||
|
|
||||||
|
# ... (previous tests)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a non-existent model and processor
|
||||||
|
def test_kosmos2_with_non_existent_model_and_processor(kosmos2):
    """Calling the model with both ``model`` and ``processor`` set to ``None`` must raise."""
    # Setup stays outside pytest.raises so only the call itself can satisfy it.
    kosmos2.model = None
    kosmos2.processor = None
    with pytest.raises(Exception):
        kosmos2(img="sample_image.jpg")
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a corrupted image
|
||||||
|
def test_kosmos2_with_corrupted_image(kosmos2, tmp_path):
    """A ``.jpg`` file holding arbitrary bytes must make the model raise."""
    image_dir = tmp_path / "images"
    image_dir.mkdir()

    # Write bytes that are not valid image data.
    broken_file = image_dir / "corrupted_image.jpg"
    broken_file.write_bytes(b"corrupted data")

    with pytest.raises(Exception):
        kosmos2(img=broken_file)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a large batch size
|
||||||
|
def test_kosmos2_with_large_batch_size(kosmos2, sample_image):
    """With ``batch_size`` set to 32, the sample image is expected to give empty ``Detections``."""
    setattr(kosmos2, "batch_size", 32)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with an invalid compute type
|
||||||
|
def test_kosmos2_with_invalid_compute_type(kosmos2, sample_image):
    """With ``compute_type`` set to an unknown string, calling the model must raise."""
    setattr(kosmos2, "compute_type", "invalid_compute_type")
    with pytest.raises(Exception):
        kosmos2(img=sample_image)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a valid HF API key
|
||||||
|
def test_kosmos2_with_valid_hf_api_key(kosmos2, sample_image):
    """With ``hf_api_key`` set, the test expects exactly two detections."""
    # NOTE(review): "valid_api_key" is a placeholder string; confirm this is
    # meant to run against a real key.
    setattr(kosmos2, "hf_api_key", "valid_api_key")
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with an invalid HF API key
|
||||||
|
def test_kosmos2_with_invalid_hf_api_key(kosmos2, sample_image):
    """With ``hf_api_key`` set to "invalid_api_key", calling the model must raise."""
    setattr(kosmos2, "hf_api_key", "invalid_api_key")
    with pytest.raises(Exception):
        kosmos2(img=sample_image)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a very long generated text
|
||||||
|
def test_kosmos2_with_long_generated_text(kosmos2, sample_image, monkeypatch):
    """With ``model.generate`` returning 10000 "A"s, empty ``Detections`` are expected."""
    long_text = "A" * 10000

    def fake_generate(*args, **kwargs):
        # Accept and ignore whatever the model passes.
        return long_text

    monkeypatch.setattr(kosmos2.model, "generate", fake_generate)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with entities containing special characters
|
||||||
|
def test_kosmos2_with_entities_containing_special_characters(
    kosmos2, sample_image, monkeypatch
):
    """One entity whose name has non-ASCII characters should give one detection."""
    special_entity = (
        "entity1 with special characters (ü, ö, etc.)",
        (0.1, 0.2, 0.3, 0.4),
    )

    def one_special_entity(text):
        return [special_entity]

    monkeypatch.setattr(kosmos2, "extract_entities", one_special_entity)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with image containing multiple objects
|
||||||
|
def test_kosmos2_with_image_containing_multiple_objects(
    kosmos2, sample_image, monkeypatch
):
    """Two extracted entities should give two detections."""

    def two_entities(text):
        first = ("entity1", (0.1, 0.2, 0.3, 0.4))
        second = ("entity2", (0.5, 0.6, 0.7, 0.8))
        return [first, second]

    monkeypatch.setattr(kosmos2, "extract_entities", two_entities)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with image containing no objects
|
||||||
|
def test_kosmos2_with_image_containing_no_objects(kosmos2, sample_image, monkeypatch):
    """No extracted entities should give empty ``Detections``."""

    def nothing_found(text):
        return []

    monkeypatch.setattr(kosmos2, "extract_entities", nothing_found)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with a valid YouTube video URL
|
||||||
|
def test_kosmos2_with_valid_youtube_video_url(kosmos2):
    """Calling the model with a ``video_url`` keyword is expected to give two detections."""
    # NOTE(review): "VIDEO_ID" is a placeholder, not a real video id.
    found = kosmos2(video_url="https://www.youtube.com/watch?v=VIDEO_ID")
    assert isinstance(found, Detections)
    assert len(found.xyxy) == len(found.class_id) == len(found.confidence) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with an invalid YouTube video URL
|
||||||
|
def test_kosmos2_with_invalid_youtube_video_url(kosmos2):
    """Calling the model with a malformed YouTube URL must raise."""
    with pytest.raises(Exception):
        kosmos2(video_url="https://www.youtube.com/invalid_video")
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 with no YouTube video URL provided
|
||||||
|
def test_kosmos2_with_no_youtube_video_url(kosmos2):
    """Calling the model with ``video_url=None`` must raise."""
    absent_url = None
    with pytest.raises(Exception):
        kosmos2(video_url=absent_url)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 installation
|
||||||
|
def test_kosmos2_installation():
    """After ``Kosmos2().install()`` both video files must exist; they are then removed."""
    # NOTE(review): nothing visible here shows install() creating these files;
    # confirm the expectation against Kosmos2.install.
    expected_files = ("video.mp4", "video.mp3")

    model = Kosmos2()
    model.install()

    for file_name in expected_files:
        assert os.path.exists(file_name)
    for file_name in expected_files:
        os.remove(file_name)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 termination
|
||||||
|
def test_kosmos2_termination(kosmos2):
    """After ``terminate()``, the ``process`` attribute must be ``None``."""
    kosmos2.terminate()
    remaining_process = kosmos2.process
    assert remaining_process is None
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 start_process method
|
||||||
|
def test_kosmos2_start_process(kosmos2):
    """After ``start_process()``, the ``process`` attribute must not be ``None``."""
    kosmos2.start_process()
    started_process = kosmos2.process
    assert started_process is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 preprocess_code method
|
||||||
|
def test_kosmos2_preprocess_code(kosmos2):
    """``preprocess_code`` must return a string containing "end_of_execution"."""
    processed = kosmos2.preprocess_code("print('Hello, World!')")
    assert isinstance(processed, str)
    assert "end_of_execution" in processed
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 run method with debug mode
|
||||||
|
def test_kosmos2_run_with_debug_mode(kosmos2, sample_image):
    """With ``debug_mode`` enabled, calling the model still returns ``Detections``."""
    setattr(kosmos2, "debug_mode", True)
    found = kosmos2(img=sample_image)
    assert isinstance(found, Detections)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 handle_stream_output method
|
||||||
|
def test_kosmos2_handle_stream_output(kosmos2):
    """``handle_stream_output`` must accept a non-error string without raising."""
    # No assertion: the test only requires the call to complete.
    kosmos2.handle_stream_output("Sample output", is_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 run method with invalid image path
|
||||||
|
def test_kosmos2_run_with_invalid_image_path(kosmos2):
    """``run`` with a path string that is not a real image must raise."""
    bad_path = "invalid_image_path.jpg"
    with pytest.raises(Exception):
        kosmos2.run(img=bad_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Test Kosmos2 run method with invalid video URL
|
||||||
|
def test_kosmos2_run_with_invalid_video_url(kosmos2):
    """``run`` with a string that is not a URL must raise."""
    bad_url = "invalid_video_url"
    with pytest.raises(Exception):
        kosmos2.run(video_url=bad_url)
|
||||||
|
|
||||||
|
|
||||||
|
# ... (more tests)
|
@ -0,0 +1,206 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import whisperx
|
||||||
|
from pydub import AudioSegment
|
||||||
|
from pytube import YouTube
|
||||||
|
from your_module import SpeechToText
|
||||||
|
|
||||||
|
from swarms.models.whisperx import SpeechToText
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture to create a temporary directory for testing
|
||||||
|
@pytest.fixture
def temp_dir():
    """Yield the path of a temporary directory that is removed after the test."""
    scratch = tempfile.TemporaryDirectory()
    with scratch as directory_path:
        yield directory_path
|
||||||
|
|
||||||
|
|
||||||
|
# Mock subprocess.run to prevent actual installation during tests
|
||||||
|
@patch.object(subprocess, "run")
def test_speech_to_text_install(mock_run):
    """``install`` must invoke pip for whisperx; ``subprocess.run`` is mocked so nothing is installed."""
    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
    stt.install()
    # install() may run pip more than once, and assert_called_with only
    # inspects the most recent call, so look for the whisperx call among all.
    mock_run.assert_any_call(["pip", "install", "whisperx"])
|
||||||
|
|
||||||
|
|
||||||
|
# Mock pytube.YouTube and pytube.Streams for download tests
|
||||||
|
@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir):
    """``download_youtube_video`` must return the path of an existing ``.mp3`` file."""
    # NOTE(review): "pytube.YouTube" is patched on the pytube module; confirm
    # the module under test actually resolves YouTube through that attribute.
    url = "https://www.youtube.com/watch?v=MJd6pr16LRM"

    # Mock YouTube and streams
    fake_stream = mock_streams().filter().first()
    fake_stream.download.return_value = os.path.join(temp_dir, "video.mp4")
    mock_youtube.return_value = mock_youtube
    mock_youtube.streams = mock_streams

    produced = SpeechToText(url).download_youtube_video()

    assert os.path.exists(produced)
    assert produced.endswith(".mp3")
|
||||||
|
|
||||||
|
|
||||||
|
# Mock whisperx.load_model and whisperx.load_audio for transcribe tests
|
||||||
|
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_youtube_video(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    temp_dir,
):
    """With the whisperx calls mocked, the transcription must equal the single segment text."""
    # NOTE(review): the download step is not mocked in this test; confirm
    # whether it is meant to reach the network.
    segment_list = [{"text": "Hello, World!"}]

    # Mock whisperx functions; load_model returns the mock itself so that
    # `.transcribe` can be configured on it.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": segment_list}

    # Mock diarization pipeline
    mock_diarization.return_value = None

    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM/video")
    text = stt.transcribe_youtube_video()

    assert text == "Hello, World!"
|
||||||
|
|
||||||
|
|
||||||
|
# More tests for different scenarios and edge cases can be added here.
|
||||||
|
|
||||||
|
|
||||||
|
# Test transcribe method with provided audio file
|
||||||
|
def test_speech_to_text_transcribe_audio_file(temp_dir):
    """Transcribing a 500 ms silent mp3 is expected to give an empty string."""
    # Create a temporary audio file containing only silence.
    silent_clip = AudioSegment.silent(duration=500)
    clip_path = os.path.join(temp_dir, "test_audio.mp3")
    silent_clip.export(clip_path, format="mp3")

    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
    text = stt.transcribe(clip_path)

    assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
# Test transcribe method when Whisperx fails
|
||||||
|
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
def test_speech_to_text_transcribe_whisperx_failure(
    mock_load_audio, mock_load_model, temp_dir
):
    """A failure inside ``whisperx.load_model`` must surface from ``transcribe``."""
    # Mock whisperx functions to raise an exception
    mock_load_model.side_effect = Exception("Whisperx failed")
    mock_load_audio.return_value = "audio_path"

    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")

    # transcribe() has no handler around load_model, so the error propagates
    # to the caller instead of being returned as text.
    with pytest.raises(Exception, match="Whisperx failed"):
        stt.transcribe("audio_path")
|
||||||
|
|
||||||
|
|
||||||
|
# Test transcribe method with missing 'segments' key in Whisperx output
|
||||||
|
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_missing_segments(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """A transcription result without a ``"segments"`` key must raise ``KeyError``."""
    # Mock whisperx functions to return incomplete output
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {"language": "en"}

    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {}

    # Mock diarization pipeline
    mock_diarization.return_value = None

    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")

    # transcribe() indexes result["segments"] when calling whisperx.align,
    # outside its try block, so the missing key propagates as KeyError.
    with pytest.raises(KeyError):
        stt.transcribe("audio_path")
|
||||||
|
|
||||||
|
|
||||||
|
# Test transcribe method with Whisperx align failure
|
||||||
|
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_align_failure(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """A failure inside ``whisperx.align`` must surface from ``transcribe``."""
    # Mock whisperx functions to raise an exception during align
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }

    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.side_effect = Exception("Align failed")

    # Mock diarization pipeline
    mock_diarization.return_value = None

    stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")

    # transcribe() has no handler around whisperx.align, so the error
    # propagates to the caller instead of being returned as text.
    with pytest.raises(Exception, match="Align failed"):
        stt.transcribe("audio_path")
|
||||||
|
|
||||||
|
|
||||||
|
# Test transcribe_youtube_video when Whisperx diarization fails
|
||||||
|
@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
@patch("whisperx.DiarizationPipeline")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
def test_speech_to_text_transcribe_diarization_failure(
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_diarization,
    mock_streams,
    mock_youtube,
    temp_dir,
):
    """The test expects the diarization error message to come back as the transcription."""
    # NOTE(review): whisperx.load_model is not mocked here, and the expected
    # value assumes the error message is returned rather than raised; confirm
    # both against the implementation.
    url = "https://www.youtube.com/watch?v=MJd6pr16LRM"

    # Mock YouTube and streams
    fake_stream = mock_streams().filter().first()
    fake_stream.download.return_value = os.path.join(temp_dir, "video.mp4")
    mock_youtube.return_value = mock_youtube
    mock_youtube.streams = mock_streams

    # Mock whisperx functions
    aligned = {"segments": [{"text": "Hello, World!"}]}
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = aligned

    # Mock diarization pipeline to raise an exception
    mock_diarization.side_effect = Exception("Diarization failed")

    text = SpeechToText(url).transcribe_youtube_video()

    assert text == "Diarization failed"
|
||||||
|
|
||||||
|
|
||||||
|
# Add more tests for other scenarios and edge cases as needed.
|
@ -0,0 +1,802 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from swarms.tools.tool import BaseTool, Runnable, StructuredTool, Tool, tool
|
||||||
|
|
||||||
|
# Define test data
|
||||||
|
# Shared input mapping and expected result used by the invoke tests.
test_input = dict(key1="value1", key2="value2")
expected_output = "expected_output_value"

# Module-level value for tests involving global variables.
global_var = "global"
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for BaseTool
|
||||||
|
def test_base_tool_init():
    """``BaseTool()`` called without arguments must give a ``BaseTool`` instance."""
    instance = BaseTool()
    assert isinstance(instance, BaseTool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_tool_invoke():
    """``BaseTool().invoke(test_input)`` must equal ``expected_output``."""
    outcome = BaseTool().invoke(test_input)
    assert outcome == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for Tool
|
||||||
|
def test_tool_init():
    """``Tool()`` called without arguments must give a ``Tool`` instance."""
    instance = Tool()
    assert isinstance(instance, Tool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_invoke():
    """Invoke a default Tool with the shared payload and compare the result."""
    outcome = Tool().invoke(test_input)
    assert outcome == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for StructuredTool
|
||||||
|
def test_structured_tool_init():
    """Construct a StructuredTool with no arguments and check its type."""
    instance = StructuredTool()
    assert isinstance(instance, StructuredTool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_invoke():
    """Invoke a default StructuredTool with the shared payload and compare the result."""
    outcome = StructuredTool().invoke(test_input)
    assert outcome == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
# Test additional functionality and edge cases as needed
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_creation():
    """Check that name, func and description passed to Tool are exposed as attributes."""
    settings = {
        "name": "test_tool",
        "func": lambda x: x,
        "description": "Test tool",
    }
    created = Tool(**settings)
    assert created.name == "test_tool"
    assert created.func is not None
    assert created.description == "Test tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke():
    """Run Tool.ainvoke to completion and check that the input is echoed back.

    ainvoke is the asynchronous entry point, so calling it only creates an
    awaitable. The awaitable has to be driven by an event loop before its
    result can be compared; a bare coroutine object never equals a string.
    """
    import asyncio  # local import keeps this test self-contained

    echo_tool = Tool(name="test_tool", func=lambda x: x, description="Test tool")
    result = asyncio.run(echo_tool.ainvoke("input_data"))
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke_with_coroutine():
    """Run ainvoke on a Tool built from a coroutine and check the echoed value.

    The awaitable returned by ainvoke is executed on an event loop; comparing
    the un-awaited coroutine object with a string could never succeed.
    """
    import asyncio  # local import keeps this test self-contained

    async def async_function(input_data):
        return input_data

    echo_tool = Tool(name="test_tool", coroutine=async_function, description="Test tool")
    result = asyncio.run(echo_tool.ainvoke("input_data"))
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_args():
    """Check that Tool.args reports a single string-typed ``tool_input`` entry."""

    def sample_function(input_data):
        return input_data

    expected_args = {"tool_input": {"type": "string"}}
    created = Tool(name="test_tool", func=sample_function, description="Test tool")
    assert created.args == expected_args
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for StructuredTool class
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_creation():
    """Check that StructuredTool exposes the constructor arguments as attributes."""

    class SampleArgsSchema:
        ...

    settings = dict(
        name="test_tool",
        func=lambda x: x,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    created = StructuredTool(**settings)
    assert created.name == "test_tool"
    assert created.func is not None
    assert created.description == "Test tool"
    assert created.args_schema == SampleArgsSchema
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke():
    """Run StructuredTool.ainvoke to completion and check the echoed value.

    ainvoke only creates an awaitable; it is executed on an event loop here
    because an un-awaited coroutine object never compares equal to a string.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    structured = StructuredTool(
        name="test_tool",
        func=lambda x: x,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    result = asyncio.run(structured.ainvoke({"tool_input": "input_data"}))
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_with_coroutine():
    """Run ainvoke on a coroutine-backed StructuredTool and check the echoed value.

    The awaitable returned by ainvoke is executed on an event loop; the bare
    coroutine object would never equal the expected string.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    async def async_function(input_data):
        return input_data

    structured = StructuredTool(
        name="test_tool",
        coroutine=async_function,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    result = asyncio.run(structured.ainvoke({"tool_input": "input_data"}))
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_args():
    """Check that StructuredTool.args reports a single string-typed ``tool_input``."""

    class SampleArgsSchema:
        ...

    def sample_function(input_data):
        return input_data

    expected_args = {"tool_input": {"type": "string"}}
    created = StructuredTool(
        name="test_tool",
        func=sample_function,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    assert created.args == expected_args
|
||||||
|
|
||||||
|
|
||||||
|
# Additional tests for exception handling
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke_exception():
    """Expect NotImplementedError when ainvoke runs on a Tool whose func is None.

    The awaitable must actually be executed for any exception to surface, so
    it is run on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    broken = Tool(name="test_tool", func=None, description="Test tool")
    with pytest.raises(NotImplementedError):
        asyncio.run(broken.ainvoke("input_data"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke_with_coroutine_exception():
    """Expect NotImplementedError when ainvoke runs on a Tool whose coroutine is None.

    The awaitable must actually be executed for any exception to surface, so
    it is run on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    broken = Tool(name="test_tool", coroutine=None, description="Test tool")
    with pytest.raises(NotImplementedError):
        asyncio.run(broken.ainvoke("input_data"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_exception():
    """Expect NotImplementedError when ainvoke runs on a StructuredTool with func=None.

    The awaitable must actually be executed for any exception to surface, so
    it is run on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    broken = StructuredTool(
        name="test_tool",
        func=None,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    with pytest.raises(NotImplementedError):
        asyncio.run(broken.ainvoke({"tool_input": "input_data"}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_with_coroutine_exception():
    """Expect NotImplementedError when ainvoke runs on a StructuredTool with coroutine=None.

    The awaitable must actually be executed for any exception to surface, so
    it is run on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    broken = StructuredTool(
        name="test_tool",
        coroutine=None,
        description="Test tool",
        args_schema=SampleArgsSchema,
    )
    with pytest.raises(NotImplementedError):
        asyncio.run(broken.ainvoke({"tool_input": "input_data"}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_description_not_provided():
    """Check that a Tool created without a description reports an empty string."""
    created = Tool(name="test_tool", func=lambda x: x)
    assert created.name == "test_tool"
    assert created.func is not None
    assert created.description == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_invoke_with_callbacks():
    """Invoke a Tool with a mock ``callbacks`` object and verify both hooks fire once."""

    def sample_function(input_data, callbacks=None):
        # Fire the lifecycle hooks only when a callbacks object was supplied.
        if callbacks:
            callbacks.on_start()
            callbacks.on_finish()
        return input_data

    hooks = MagicMock()
    created = Tool(name="test_tool", func=sample_function)
    outcome = created.invoke("input_data", callbacks=hooks)

    assert outcome == "input_data"
    hooks.on_start.assert_called_once()
    hooks.on_finish.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_invoke_with_new_argument():
    """Invoke a Tool while passing ``callbacks=None`` explicitly and check the echo."""

    def sample_function(input_data, callbacks=None):
        return input_data

    created = Tool(name="test_tool", func=sample_function)
    outcome = created.invoke("input_data", callbacks=None)
    assert outcome == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke_with_new_argument():
    """Run ainvoke with ``callbacks=None`` passed explicitly and check the echo.

    The awaitable returned by ainvoke is executed on an event loop; the bare
    coroutine object would never equal the expected string.
    """
    import asyncio  # local import keeps this test self-contained

    async def async_function(input_data, callbacks=None):
        return input_data

    echo_tool = Tool(name="test_tool", coroutine=async_function)
    result = asyncio.run(echo_tool.ainvoke("input_data", callbacks=None))
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_description_from_docstring():
    """Check that Tool.description equals the wrapped function's docstring."""

    def sample_function(input_data):
        """Sample function docstring"""
        return input_data

    created = Tool(name="test_tool", func=sample_function)
    expected_description = "Sample function docstring"
    assert created.description == expected_description
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_ainvoke_with_exceptions():
    """Expect the ValueError raised inside the coroutine to propagate out of ainvoke.

    The coroutine body only runs once the awaitable is executed, so it is
    driven on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    async def async_function(input_data):
        raise ValueError("Test exception")

    failing = Tool(name="test_tool", coroutine=async_function)
    with pytest.raises(ValueError):
        asyncio.run(failing.ainvoke("input_data"))
|
||||||
|
|
||||||
|
|
||||||
|
# Additional tests for StructuredTool class
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_infer_schema_false():
    """Check that args_schema stays None when no schema is given and inference is off."""

    def sample_function(input_data):
        return input_data

    settings = dict(
        name="test_tool",
        func=sample_function,
        args_schema=None,
        infer_schema=False,
    )
    created = StructuredTool(**settings)
    assert created.args_schema is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_with_callbacks():
    """Run ainvoke with a mock ``callbacks`` object and verify both hooks fire once.

    The awaitable returned by ainvoke is executed on an event loop; without
    that the wrapped function never runs and no hook could be observed.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    def sample_function(input_data, callbacks=None):
        # Fire the lifecycle hooks only when a callbacks object was supplied.
        if callbacks:
            callbacks.on_start()
            callbacks.on_finish()
        return input_data

    structured = StructuredTool(
        name="test_tool",
        func=sample_function,
        args_schema=SampleArgsSchema,
    )
    callbacks = MagicMock()
    result = asyncio.run(
        structured.ainvoke({"tool_input": "input_data"}, callbacks=callbacks)
    )
    assert result == "input_data"
    callbacks.on_start.assert_called_once()
    callbacks.on_finish.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_description_not_provided():
    """Check that a StructuredTool created without a description reports ''."""

    class SampleArgsSchema:
        ...

    created = StructuredTool(
        name="test_tool", func=lambda x: x, args_schema=SampleArgsSchema
    )
    assert created.name == "test_tool"
    assert created.func is not None
    assert created.description == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_args_schema():
    """Check that an explicitly supplied args_schema is stored unchanged."""

    class SampleArgsSchema:
        ...

    def sample_function(input_data):
        return input_data

    created = StructuredTool(
        name="test_tool", func=sample_function, args_schema=SampleArgsSchema
    )
    assert created.args_schema == SampleArgsSchema
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_args_schema_inference():
    """Check that args_schema is populated when none is given and inference is on."""

    def sample_function(input_data):
        return input_data

    settings = dict(
        name="test_tool",
        func=sample_function,
        args_schema=None,
        infer_schema=True,
    )
    created = StructuredTool(**settings)
    assert created.args_schema is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_with_new_argument():
    """Run ainvoke with ``callbacks=None`` passed explicitly and check the echo.

    The awaitable returned by ainvoke is executed on an event loop; the bare
    coroutine object would never equal the expected string.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    def sample_function(input_data, callbacks=None):
        return input_data

    structured = StructuredTool(
        name="test_tool",
        func=sample_function,
        args_schema=SampleArgsSchema,
    )
    result = asyncio.run(
        structured.ainvoke({"tool_input": "input_data"}, callbacks=None)
    )
    assert result == "input_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_ainvoke_with_exceptions():
    """Expect the ValueError raised inside the coroutine to propagate out of ainvoke.

    The coroutine body only runs once the awaitable is executed, so it is
    driven on an event loop inside the ``pytest.raises`` block.
    """
    import asyncio  # local import keeps this test self-contained

    class SampleArgsSchema:
        pass

    async def async_function(input_data):
        raise ValueError("Test exception")

    failing = StructuredTool(
        name="test_tool",
        coroutine=async_function,
        args_schema=SampleArgsSchema,
    )
    with pytest.raises(ValueError):
        asyncio.run(failing.ainvoke({"tool_input": "input_data"}))
|
||||||
|
|
||||||
|
|
||||||
|
# Test additional functionality and edge cases
|
||||||
|
def test_tool_with_fixture(some_fixture):
    """Invoke a default Tool with the shared payload while a fixture is requested.

    NOTE(review): ``some_fixture`` is not defined in the visible part of this
    module; presumably it comes from a conftest -- confirm it exists.
    """
    outcome = Tool().invoke(test_input)
    assert outcome == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_with_fixture(some_fixture):
    """Invoke a default StructuredTool with the shared payload while a fixture is requested.

    NOTE(review): ``some_fixture`` is not defined in the visible part of this
    module; presumably it comes from a conftest -- confirm it exists.
    """
    outcome = StructuredTool().invoke(test_input)
    assert outcome == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_tool_verbose_logging(caplog):
    """Invoke a verbose BaseTool and look for 'Verbose logging' in the captured log."""
    verbose_tool = BaseTool(verbose=True)
    outcome = verbose_tool.invoke(test_input)
    assert outcome == expected_output
    captured_text = caplog.text
    assert "Verbose logging" in captured_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_exception_handling():
    """Expect some exception when a default Tool is invoked with raise_exception=True."""
    instance = Tool()
    with pytest.raises(Exception):
        instance.invoke(test_input, raise_exception=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_async_invoke():
    """Run ainvoke on a default StructuredTool and compare with the expected output.

    ainvoke only creates an awaitable; it is executed on an event loop here
    because an un-awaited coroutine object never equals the expected string.
    """
    import asyncio  # local import keeps this test self-contained

    structured = StructuredTool()
    result = asyncio.run(structured.ainvoke(test_input))
    assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_async_invoke_with_fixture(some_fixture):
    """Run ainvoke on a default Tool (with a fixture requested) and compare the result.

    ainvoke only creates an awaitable; it is executed on an event loop here
    because an un-awaited coroutine object never equals the expected string.

    NOTE(review): ``some_fixture`` is not defined in the visible part of this
    module; presumably it comes from a conftest -- confirm it exists.
    """
    import asyncio  # local import keeps this test self-contained

    instance = Tool()
    result = asyncio.run(instance.ainvoke(test_input))
    assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
# Add more tests for specific functionalities and edge cases as needed
|
||||||
|
# Helpers and fixtures for the `tool` decorator tests (imports are at the top of this module)
|
||||||
|
|
||||||
|
|
||||||
|
# Example of a mock function to be used in testing
|
||||||
|
def mock_function(arg: str) -> str:
    """Return *arg* prefixed with 'Processed '; a stand-in callable for the tests."""
    processed = f"Processed {arg}"
    return processed
|
||||||
|
|
||||||
|
|
||||||
|
# Example of a Runnable class for testing
|
||||||
|
class MockRunnable(Runnable):
    # Bare Runnable subclass used as a test double; it adds no methods or
    # attributes of its own.
    ...
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture for creating a mock function
|
||||||
|
@pytest.fixture
def mock_func():
    """Provide the module-level ``mock_function`` callable to tests."""
    provided = mock_function
    return provided
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture for creating a Runnable instance
|
||||||
|
@pytest.fixture
def mock_runnable():
    """Provide a fresh ``MockRunnable`` instance to tests."""
    instance = MockRunnable()
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
# Basic functionality tests
|
||||||
|
def test_tool_with_callable(mock_func):
    """Check that ``tool`` applied to a plain callable yields a BaseTool."""
    built = tool(mock_func)
    assert isinstance(built, BaseTool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_with_runnable(mock_runnable):
    """Check that ``tool`` applied to a Runnable instance yields a BaseTool."""
    built = tool(mock_runnable)
    assert isinstance(built, BaseTool)
|
||||||
|
|
||||||
|
|
||||||
|
# ... more basic functionality tests ...
|
||||||
|
|
||||||
|
|
||||||
|
# Argument handling tests
|
||||||
|
def test_tool_with_invalid_argument():
    """Expect ValueError when ``tool`` receives an int instead of a string/callable/Runnable."""
    invalid_argument = 123
    with pytest.raises(ValueError):
        tool(invalid_argument)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_with_multiple_arguments(mock_func):
    """Check that ``tool`` called with a name string plus a callable yields a BaseTool."""
    built = tool("mock", mock_func)
    assert isinstance(built, BaseTool)
|
||||||
|
|
||||||
|
|
||||||
|
# ... more argument handling tests ...
|
||||||
|
|
||||||
|
|
||||||
|
# Schema inference and application tests
|
||||||
|
class TestSchema(BaseModel):
    # Pydantic schema with one required string field; passed as ``args_schema``
    # by test_tool_with_args_schema.
    arg: str
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_with_args_schema(mock_func):
    """Check that a custom ``args_schema`` passed to ``tool`` is kept on the result."""
    built = tool(mock_func, args_schema=TestSchema)
    stored_schema = built.args_schema
    assert stored_schema == TestSchema
|
||||||
|
|
||||||
|
|
||||||
|
# ... more schema tests ...
|
||||||
|
|
||||||
|
|
||||||
|
# Exception handling tests
|
||||||
|
def test_tool_function_without_docstring():
    """Expect ValueError when ``tool`` is given a function that has no docstring."""

    # Deliberately left without a docstring: that absence is what is under test.
    def no_doc_func(arg: str) -> str:
        return arg

    with pytest.raises(ValueError):
        tool(no_doc_func)
|
||||||
|
|
||||||
|
|
||||||
|
# ... more exception tests ...
|
||||||
|
|
||||||
|
|
||||||
|
# Decorator behavior tests
|
||||||
|
@pytest.mark.asyncio
async def test_async_tool_function():
    """Apply the ``tool`` decorator to a coroutine function.

    The decorated coroutine carries a docstring because
    test_tool_function_without_docstring in this module expects ``tool`` to
    reject functions that lack one.
    """

    @tool
    async def async_func(arg: str) -> str:
        """Return *arg* unchanged."""
        return arg

    # Add async specific assertions here
|
||||||
|
|
||||||
|
|
||||||
|
# ... more decorator tests ...
|
||||||
|
|
||||||
|
|
||||||
|
class MockSchema(BaseModel):
    """Pydantic model with a single string field, used as ``args_schema`` in TestTool."""

    # The only field of the schema.
    arg: str
|
||||||
|
|
||||||
|
|
||||||
|
# Test suite starts here
|
||||||
|
class TestTool:
    """Tests for the ``tool`` decorator/factory.

    Every function handed to ``tool`` here carries a docstring, because
    ``test_tool_raises_error_without_docstring`` below expects ``tool`` to
    raise ValueError for functions that lack one. The only intentional
    exceptions are the helpers whose missing docstring is the thing under test.
    """

    # Basic Functionality Tests
    def test_tool_with_valid_callable_creates_base_tool(self, mock_func):
        """``tool`` applied to a callable yields a BaseTool."""
        result = tool(mock_func)
        assert isinstance(result, BaseTool)

    def test_tool_returns_correct_function_name(self, mock_func):
        """The wrapped callable keeps its original ``__name__``."""
        result = tool(mock_func)
        assert result.func.__name__ == "mock_function"

    # Argument Handling Tests
    def test_tool_with_string_and_runnable(self, mock_runnable):
        """``tool`` called with a name string plus a Runnable yields a BaseTool."""
        result = tool("mock_runnable", mock_runnable)
        assert isinstance(result, BaseTool)

    def test_tool_raises_error_with_invalid_arguments(self):
        """An int argument is rejected with ValueError."""
        with pytest.raises(ValueError):
            tool(123)

    # Schema Inference and Application Tests
    def test_tool_with_args_schema(self, mock_func):
        """A custom ``args_schema`` is kept on the resulting tool."""
        result = tool(mock_func, args_schema=MockSchema)
        assert result.args_schema == MockSchema

    def test_tool_with_infer_schema_true(self, mock_func):
        """``tool`` accepts ``infer_schema=True`` (no assertions yet)."""
        tool(mock_func, infer_schema=True)
        # Assertions related to schema inference

    # Return Direct Feature Tests
    def test_tool_with_return_direct_true(self, mock_func):
        """``tool`` accepts ``return_direct=True`` (no assertions yet)."""
        tool(mock_func, return_direct=True)
        # Assertions for return_direct behavior

    # Error Handling Tests
    def test_tool_raises_error_without_docstring(self):
        """A function without a docstring is rejected with ValueError."""

        # Deliberately left without a docstring: that absence is under test.
        def no_doc_func(arg: str) -> str:
            return arg

        with pytest.raises(ValueError):
            tool(no_doc_func)

    def test_tool_raises_error_runnable_without_object_schema(self, mock_runnable):
        """A bare Runnable passed on its own is rejected with ValueError."""
        with pytest.raises(ValueError):
            tool(mock_runnable)

    # Decorator Behavior Tests
    @pytest.mark.asyncio
    async def test_async_tool_function(self):
        """The decorator can be applied to a coroutine function."""

        @tool
        async def async_func(arg: str) -> str:
            """Return *arg* unchanged."""
            return arg

        # Assertions for async behavior

    # Integration with StructuredTool and Tool Classes
    def test_integration_with_structured_tool(self, mock_func):
        """``tool`` applied to a callable yields a StructuredTool."""
        result = tool(mock_func)
        assert isinstance(result, StructuredTool)

    # Concurrency and Async Handling Tests
    def test_concurrency_in_tool(self, mock_func):
        """Placeholder for a concurrency test."""
        # Test related to concurrency
        pass

    # Mocking and Isolation Tests
    def test_mocking_external_dependencies(self, mocker):
        """Placeholder for a test that mocks external dependencies."""
        # Use mocker to mock external dependencies
        pass

    def test_tool_with_different_return_types(self):
        """Tools may return int and bool values."""

        @tool
        def return_int(arg: str) -> int:
            """Convert *arg* to an int."""
            return int(arg)

        result = return_int("123")
        assert isinstance(result, int)
        assert result == 123

        @tool
        def return_bool(arg: str) -> bool:
            """Return True when *arg* is 'true' or 'yes' (case-insensitive)."""
            return arg.lower() in ["true", "yes"]

        result = return_bool("true")
        assert isinstance(result, bool)
        assert result is True

    # Test with multiple arguments
    def test_tool_with_multiple_args(self):
        """A tool taking two positional string arguments."""

        @tool
        def concat_strings(a: str, b: str) -> str:
            """Concatenate *a* and *b*."""
            return a + b

        result = concat_strings("Hello", "World")
        assert result == "HelloWorld"

    # Test handling of optional arguments
    def test_tool_with_optional_args(self):
        """A tool with a defaulted keyword argument."""

        @tool
        def greet(name: str, greeting: str = "Hello") -> str:
            """Return '<greeting> <name>'."""
            return f"{greeting} {name}"

        assert greet("Alice") == "Hello Alice"
        assert greet("Alice", greeting="Hi") == "Hi Alice"

    # Test with variadic arguments
    def test_tool_with_variadic_args(self):
        """A tool taking ``*args``."""

        @tool
        def sum_numbers(*numbers: int) -> int:
            """Return the sum of all positional arguments."""
            return sum(numbers)

        assert sum_numbers(1, 2, 3) == 6
        assert sum_numbers(10, 20) == 30

    # Test with keyword arguments
    def test_tool_with_kwargs(self):
        """A tool taking ``**kwargs``."""

        @tool
        def build_query(**kwargs) -> str:
            """Join the keyword arguments as 'k=v' pairs separated by '&'."""
            return "&".join(f"{k}={v}" for k, v in kwargs.items())

        assert build_query(a=1, b=2) == "a=1&b=2"
        assert build_query(foo="bar") == "foo=bar"

    # Test with mixed types of arguments
    def test_tool_with_mixed_args(self):
        """A tool mixing positional, variadic and keyword arguments."""

        @tool
        def mixed_args(a: int, b: str, *args, **kwargs) -> str:
            """Concatenate a, b, the count of extra args and the '-'-joined kwarg values."""
            return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}"

        assert mixed_args(1, "b", "c", "d", x="y", z="w") == "1b2y-w"

    # Test error handling with incorrect types
    def test_tool_error_with_incorrect_types(self):
        """Passing strings where ints are annotated is expected to raise TypeError."""

        @tool
        def add_numbers(a: int, b: int) -> int:
            """Return a + b."""
            return a + b

        with pytest.raises(TypeError):
            add_numbers("1", "2")

    # Test with nested tools
    def test_nested_tools(self):
        """One tool may call another tool."""

        @tool
        def inner_tool(arg: str) -> str:
            """Return *arg* prefixed with 'Inner '."""
            return f"Inner {arg}"

        @tool
        def outer_tool(arg: str) -> str:
            """Return the inner tool's result prefixed with 'Outer '."""
            return f"Outer {inner_tool(arg)}"

        assert outer_tool("Test") == "Outer Inner Test"

    def test_tool_with_global_variable(self):
        """A tool may read the module-level ``global_var``."""

        @tool
        def access_global(arg: str) -> str:
            """Return '<global_var> <arg>'."""
            return f"{global_var} {arg}"

        assert access_global("Var") == "global Var"

    # Test with environment variables
    def test_tool_with_env_variables(self, monkeypatch):
        """A tool may read an environment variable set through monkeypatch."""
        monkeypatch.setenv("TEST_VAR", "Environment")

        @tool
        def access_env_variable(arg: str) -> str:
            """Return '<TEST_VAR> <arg>' using the current environment."""
            import os

            return f"{os.environ['TEST_VAR']} {arg}"

        assert access_env_variable("Var") == "Environment Var"

    # ... [Previous test cases] ...

    # Test with complex data structures
    def test_tool_with_complex_data_structures(self):
        """A tool taking a dict and returning a list."""

        @tool
        def process_data(data: dict) -> list:
            """Return the dict's values ordered by sorted key."""
            return [data[key] for key in sorted(data.keys())]

        result = process_data({"b": 2, "a": 1})
        assert result == [1, 2]

    # Test handling exceptions within the tool function
    def test_tool_handling_internal_exceptions(self):
        """Exceptions raised inside the wrapped function propagate to the caller."""

        @tool
        def function_that_raises(arg: str):
            """Raise ValueError for 'error'; otherwise return *arg*."""
            if arg == "error":
                raise ValueError("Error occurred")
            return arg

        with pytest.raises(ValueError):
            function_that_raises("error")
        assert function_that_raises("ok") == "ok"

    # Test with functions returning None
    def test_tool_with_none_return(self):
        """A tool may return None."""

        @tool
        def return_none(arg: str):
            """Always return None."""
            return None

        assert return_none("anything") is None

    # Test with lambda functions
    def test_tool_with_lambda(self):
        """``tool`` applied to a lambda."""
        # NOTE(review): a lambda has no docstring, which the docstring test
        # above says ``tool`` rejects -- confirm this case is meant to pass.
        tool_lambda = tool(lambda x: x * 2)
        assert tool_lambda(3) == 6

    # Test with class methods
    def test_tool_with_class_method(self):
        """``tool`` applied to a method defined in a class body."""

        class MyClass:
            @tool
            def method(self, arg: str) -> str:
                """Return *arg* prefixed with 'Method '."""
                return f"Method {arg}"

        obj = MyClass()
        assert obj.method("test") == "Method test"

    # Test tool function with inheritance
    def test_tool_with_inheritance(self):
        """Decorated methods are reachable through a subclass instance."""

        class Parent:
            @tool
            def parent_method(self, arg: str) -> str:
                """Return *arg* prefixed with 'Parent '."""
                return f"Parent {arg}"

        class Child(Parent):
            @tool
            def child_method(self, arg: str) -> str:
                """Return *arg* prefixed with 'Child '."""
                return f"Child {arg}"

        child_obj = Child()
        assert child_obj.parent_method("test") == "Parent test"
        assert child_obj.child_method("test") == "Child test"

    # Test with decorators stacking
    def test_tool_with_multiple_decorators(self):
        """``tool`` stacked on top of another decorator."""
        import functools  # local import keeps this test self-contained

        def another_decorator(func):
            # wraps() copies the name and docstring onto the wrapper, so the
            # function that ``tool`` receives still has a docstring.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return f"Decorated {func(*args, **kwargs)}"

            return wrapper

        @tool
        @another_decorator
        def decorated_function(arg: str):
            """Return *arg* prefixed with 'Function '."""
            return f"Function {arg}"

        assert decorated_function("test") == "Decorated Function test"

    # Test tool function when used in a multi-threaded environment
    def test_tool_in_multithreaded_environment(self):
        """Call the same tool from ten threads and collect every result."""
        import threading

        @tool
        def threaded_function(arg: int) -> int:
            """Return *arg* doubled."""
            return arg * 2

        results = []

        def thread_target():
            results.append(threaded_function(5))

        threads = [threading.Thread(target=thread_target) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert results == [10] * 10

    # Test with recursive functions
    def test_tool_with_recursive_function(self):
        """A tool whose body calls itself recursively."""

        @tool
        def recursive_function(n: int) -> int:
            """Return 0 + 1 + ... + n by recursion."""
            if n == 0:
                return 0
            else:
                return n + recursive_function(n - 1)

        assert recursive_function(5) == 15
|
||||||
|
|
||||||
|
|
||||||
|
# Additional tests can be added here to cover more scenarios
|
@ -0,0 +1,246 @@
|
|||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import pytest
|
||||||
|
import subprocess
|
||||||
|
from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def subprocess_code_interpreter():
    """Yield a SubprocessCodeInterpreter set to run ``python -c``; terminate it afterwards."""
    instance = SubprocessCodeInterpreter()
    instance.start_cmd = "python -c"
    yield instance
    # Teardown: runs once the requesting test has finished.
    instance.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_code_interpreter_init():
    """Construct a BaseCodeInterpreter and check its type."""
    base = BaseCodeInterpreter()
    assert isinstance(base, BaseCodeInterpreter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_code_interpreter_run_not_implemented():
    """Expect NotImplementedError from BaseCodeInterpreter.run."""
    base = BaseCodeInterpreter()
    with pytest.raises(NotImplementedError):
        base.run("code")
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_code_interpreter_terminate_not_implemented():
    """Expect NotImplementedError from BaseCodeInterpreter.terminate."""
    base = BaseCodeInterpreter()
    with pytest.raises(NotImplementedError):
        base.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_init(subprocess_code_interpreter):
    """Check that the fixture provides a SubprocessCodeInterpreter."""
    provided = subprocess_code_interpreter
    assert isinstance(provided, SubprocessCodeInterpreter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_start_process(subprocess_code_interpreter):
    """Check that start_process leaves a non-None ``process`` attribute."""
    interp = subprocess_code_interpreter
    interp.start_process()
    assert interp.process is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_terminate(subprocess_code_interpreter):
    """Start then terminate the interpreter and check the process reports an exit status."""
    interp = subprocess_code_interpreter
    interp.start_process()
    interp.terminate()
    exit_status = interp.process.poll()
    assert exit_status is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter):
    """Run a print statement and look for its text in the yielded outputs."""
    snippet = 'print("Hello, World!")'
    chunks = list(subprocess_code_interpreter.run(snippet))
    assert any("Hello, World!" in chunk.get("output", "") for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter):
    """Run code that raises ValueError and look for its message in the yielded outputs."""
    snippet = 'print("Hello, World")\nraise ValueError("Error!")'
    chunks = list(subprocess_code_interpreter.run(snippet))
    assert any("Error!" in chunk.get("output", "") for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_keyboard_interrupt(
    subprocess_code_interpreter,
):
    """Run code that raises KeyboardInterrupt and look for that name in the outputs."""
    snippet = 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise KeyboardInterrupt'
    chunks = list(subprocess_code_interpreter.run(snippet))
    assert any("KeyboardInterrupt" in chunk.get("output", "") for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_max_retries(
    subprocess_code_interpreter, monkeypatch
):
    """Make every Popen call fail and expect the 'maximum retries' message in the outputs."""

    def failing_popen(*args, **kwargs):
        # Stand-in for subprocess.Popen that always fails.
        raise subprocess.CalledProcessError(1, "mocked_cmd")

    monkeypatch.setattr(subprocess, "Popen", failing_popen)

    snippet = 'print("Hello, World!")'
    chunks = list(subprocess_code_interpreter.run(snippet))
    assert any(
        "Maximum retries reached. Could not execute code." in chunk.get("output", "")
        for chunk in chunks
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_retry_on_error(
    subprocess_code_interpreter, monkeypatch
):
    """Fail the first Popen call, succeed on the next, and expect the echoed text.

    The genuine ``subprocess.Popen`` is captured before it is patched. Once
    ``monkeypatch.setattr`` has run, ``subprocess.Popen`` refers to the fake
    itself, so calling it from inside the fake would recurse without end
    instead of spawning a process.
    """
    real_popen = subprocess.Popen
    popen_count = 0

    def mock_subprocess_popen(*args, **kwargs):
        nonlocal popen_count
        if popen_count == 0:
            # First attempt: simulate a failed launch.
            popen_count += 1
            raise subprocess.CalledProcessError(1, "mocked_cmd")
        # Later attempts: spawn a real process through the unpatched Popen.
        return real_popen(
            "echo 'Hello, World!'",
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

    monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen)

    code = 'print("Hello, World!")'
    result = list(subprocess_code_interpreter.run(code))
    assert any("Hello, World!" in output.get("output", "") for output in result)
|
||||||
|
|
||||||
|
|
||||||
|
# Add more tests to cover other aspects of the code and edge cases as needed
|
||||||
|
|
||||||
|
# The tests below reuse the imports and the subprocess_code_interpreter fixture defined at the top of this module
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_line_postprocessor(subprocess_code_interpreter):
    """Check that line_postprocessor returns the line unchanged."""
    original_line = "This is a test line"
    returned_line = subprocess_code_interpreter.line_postprocessor(original_line)
    # No processing is expected, so the text should come back as given.
    assert returned_line == original_line
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_preprocess_code(subprocess_code_interpreter):
    """Check that preprocess_code returns the code unchanged."""
    snippet = 'print("Hello, World!")'
    returned_code = subprocess_code_interpreter.preprocess_code(snippet)
    # No preprocessing is expected, so the text should come back as given.
    assert returned_code == snippet
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_detect_active_line(subprocess_code_interpreter):
    """Check that detect_active_line returns 5 for 'Active line: 5'."""
    marker = "Active line: 5"
    detected = subprocess_code_interpreter.detect_active_line(marker)
    assert detected == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_detect_end_of_execution(
    subprocess_code_interpreter,
):
    """Check that detect_end_of_execution returns True for 'Execution completed.'."""
    marker = "Execution completed."
    finished = subprocess_code_interpreter.detect_end_of_execution(marker)
    assert finished is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_debug_mode(
    subprocess_code_interpreter, capsys
):
    """With debug_mode enabled, run() should print both debug banners to stdout."""
    subprocess_code_interpreter.debug_mode = True
    # Exhaust the iterable returned by run(); the yielded items are not inspected here.
    for _ in subprocess_code_interpreter.run('print("Hello, World!")'):
        pass
    stdout_text = capsys.readouterr().out
    for banner in ("Running code:\n", "Received output line:\n"):
        assert banner in stdout_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_no_debug_mode(
    subprocess_code_interpreter, capsys
):
    """With debug_mode disabled, run() should print neither debug banner to stdout."""
    subprocess_code_interpreter.debug_mode = False
    # Exhaust the iterable returned by run(); the yielded items are not inspected here.
    for _ in subprocess_code_interpreter.run('print("Hello, World!")'):
        pass
    stdout_text = capsys.readouterr().out
    for banner in ("Running code:\n", "Received output line:\n"):
        assert banner not in stdout_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_empty_output_queue(
    subprocess_code_interpreter,
):
    """Nothing yielded by run() for a plain print should contain 'active_line'."""
    chunks = list(subprocess_code_interpreter.run('print("Hello, World!")'))
    assert all("active_line" not in chunk for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_handle_stream_output_stdout(
    subprocess_code_interpreter,
):
    """A line written to process.stdout should show up on output_queue.

    The queue read is bounded so that a missing item fails the test instead
    of blocking the whole test session forever.
    """
    line = "This is a test line"
    # NOTE(review): a Thread object is passed as the first argument here;
    # confirm that this is what handle_stream_output expects.
    subprocess_code_interpreter.handle_stream_output(threading.current_thread(), False)
    subprocess_code_interpreter.process.stdout.write(line + "\n")
    subprocess_code_interpreter.process.stdout.flush()
    # Bounded wait replaces the fixed sleep followed by an unbounded get().
    # Assumes output_queue follows the queue.Queue API (get(timeout=...)) - TODO confirm.
    output = subprocess_code_interpreter.output_queue.get(timeout=1)
    assert output["output"] == line
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_handle_stream_output_stderr(
    subprocess_code_interpreter,
):
    """A line written to process.stderr should show up on output_queue.

    The queue read is bounded so that a missing item fails the test instead
    of blocking the whole test session forever.
    """
    line = "This is an error line"
    # NOTE(review): a Thread object is passed as the first argument here;
    # confirm that this is what handle_stream_output expects.
    subprocess_code_interpreter.handle_stream_output(threading.current_thread(), True)
    subprocess_code_interpreter.process.stderr.write(line + "\n")
    subprocess_code_interpreter.process.stderr.flush()
    # Bounded wait replaces the fixed sleep followed by an unbounded get().
    # Assumes output_queue follows the queue.Queue API (get(timeout=...)) - TODO confirm.
    output = subprocess_code_interpreter.output_queue.get(timeout=1)
    assert output["output"] == line
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_preprocess_code(
    subprocess_code_interpreter, capsys
):
    """run() should execute the code produced by a custom preprocess_code hook.

    The hook rewrites only the printed text, so the rewritten program is
    still valid and the hook's effect is observable in the output.
    Upper-casing the entire source would also upper-case ``print`` and could
    never yield the original-case text the assertion used to look for.
    """
    code = 'print("Hello, World!")'
    # Modify code in preprocess_code: change the string payload only.
    # Assumes run() passes the code through self.preprocess_code - TODO confirm.
    subprocess_code_interpreter.preprocess_code = lambda src: src.replace(
        "Hello, World!", "HELLO, WORLD!"
    )
    result = list(subprocess_code_interpreter.run(code))
    assert any("HELLO, WORLD!" in output.get("output", "") for output in result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_exception(
    subprocess_code_interpreter, capsys
):
    """With a bogus start_cmd, run() should yield a 'Maximum retries reached' message."""
    # Force an exception during subprocess creation
    subprocess_code_interpreter.start_cmd = "nonexistent_command"
    yielded = list(subprocess_code_interpreter.run('print("Hello, World!")'))
    # Lazily pull each item's "output" text; items without one count as "".
    texts = map(lambda item: item.get("output", ""), yielded)
    assert any("Maximum retries reached" in text for text in texts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_active_line(
    subprocess_code_interpreter, capsys
):
    """run() on a two-line snippet should yield an item whose active_line equals 5.

    NOTE(review): the snippet has only two lines and prints the value 5;
    confirm that 5 is really the intended active line and not a mix-up with
    the printed value.
    """
    snippet = "a = 5\nprint(a)"
    yielded = list(subprocess_code_interpreter.run(snippet))
    line_numbers = map(lambda item: item.get("active_line"), yielded)
    assert any(line_no == 5 for line_no in line_numbers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_end_of_execution(
    subprocess_code_interpreter, capsys
):
    """For code without an active line marker, some yielded item should have no active_line."""
    snippet = 'print("Hello, World!")'
    yielded = list(subprocess_code_interpreter.run(snippet))
    # .get() returns None both for a missing key and for an explicit None value.
    line_numbers = map(lambda item: item.get("active_line"), yielded)
    assert any(line_no is None for line_no in line_numbers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_multiple_lines(
    subprocess_code_interpreter, capsys
):
    """A three-line snippet that prints 5 + 10 should produce output containing '15'."""
    snippet = "a = 5\nb = 10\nprint(a + b)"
    yielded = list(subprocess_code_interpreter.run(snippet))
    # Lazily pull each item's "output" text; items without one count as "".
    texts = map(lambda item: item.get("output", ""), yielded)
    assert any("15" in text for text in texts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_code_interpreter_run_with_unicode_characters(
    subprocess_code_interpreter, capsys
):
    """Non-ASCII text printed by the code should appear intact in the yielded output."""
    greeting = "こんにちは、世界"
    snippet = 'print("こんにちは、世界")'
    yielded = list(subprocess_code_interpreter.run(snippet))
    # Lazily pull each item's "output" text; items without one count as "".
    texts = map(lambda item: item.get("output", ""), yielded)
    assert any(greeting in text for text in texts)
|
Loading…
Reference in new issue