Merge branch 'master' of https://github.com/kyegomez/swarms
commit
bad5aa8deb
@ -0,0 +1,146 @@
|
||||
# BaseChunker Documentation
|
||||
|
||||
## Table of Contents
|
||||
1. [Introduction](#introduction)
|
||||
2. [Overview](#overview)
|
||||
3. [Installation](#installation)
|
||||
4. [Usage](#usage)
|
||||
1. [BaseChunker Class](#basechunker-class)
|
||||
2. [Examples](#examples)
|
||||
5. [Additional Information](#additional-information)
|
||||
6. [Conclusion](#conclusion)
|
||||
|
||||
---
|
||||
|
||||
## 1. Introduction <a name="introduction"></a>
|
||||
|
||||
The `BaseChunker` module is a tool for splitting text into smaller chunks that can be processed by a language model. It is a fundamental component in natural language processing tasks that require handling long or complex text inputs.
|
||||
|
||||
This documentation provides an extensive guide on using the `BaseChunker` module, explaining its purpose, parameters, and usage.
|
||||
|
||||
---
|
||||
|
||||
## 2. Overview <a name="overview"></a>
|
||||
|
||||
The `BaseChunker` module is designed to address the challenge of processing lengthy text inputs that exceed the maximum token limit of language models. By breaking such text into smaller, manageable chunks, it enables efficient and accurate processing.
|
||||
|
||||
Key features and parameters of the `BaseChunker` module include:
|
||||
- `separators`: Specifies a list of `ChunkSeparator` objects used to split the text into chunks.
|
||||
- `tokenizer`: Defines the tokenizer to be used for counting tokens in the text.
|
||||
- `max_tokens`: Sets the maximum token limit for each chunk.
|
||||
|
||||
The `BaseChunker` module facilitates the chunking process and ensures that the generated chunks are within the token limit.
|
||||
|
||||
---
|
||||
|
||||
## 3. Installation <a name="installation"></a>
|
||||
|
||||
Before using the `BaseChunker` module, ensure you have the required dependencies installed. The module relies on `griptape` and `swarms` libraries. You can install these dependencies using pip:
|
||||
|
||||
```bash
|
||||
pip install griptape swarms
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Usage <a name="usage"></a>
|
||||
|
||||
In this section, we'll cover how to use the `BaseChunker` module effectively. It consists of the `BaseChunker` class and provides examples to demonstrate its usage.
|
||||
|
||||
### 4.1. `BaseChunker` Class <a name="basechunker-class"></a>
|
||||
|
||||
The `BaseChunker` class is the core component of the `BaseChunker` module. It is used to create a `BaseChunker` instance, which can split text into chunks efficiently.
|
||||
|
||||
#### Parameters:
|
||||
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the text into chunks.
|
||||
- `tokenizer` (OpenAiTokenizer): Defines the tokenizer to be used for counting tokens in the text.
|
||||
- `max_tokens` (int): Sets the maximum token limit for each chunk.
|
||||
|
||||
### 4.2. Examples <a name="examples"></a>
|
||||
|
||||
Let's explore how to use the `BaseChunker` class with different scenarios and applications.
|
||||
|
||||
#### Example 1: Basic Chunking
|
||||
|
||||
```python
|
||||
from basechunker import BaseChunker, ChunkSeparator
|
||||
|
||||
# Initialize the BaseChunker
|
||||
chunker = BaseChunker()
|
||||
|
||||
# Text to be chunked
|
||||
input_text = "This is a long text that needs to be split into smaller chunks for processing."
|
||||
|
||||
# Chunk the text
|
||||
chunks = chunker.chunk(input_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}: {chunk.value}")
|
||||
```
|
||||
|
||||
#### Example 2: Custom Separators
|
||||
|
||||
```python
|
||||
from basechunker import BaseChunker, ChunkSeparator
|
||||
|
||||
# Define custom separators
|
||||
custom_separators = [ChunkSeparator(","), ChunkSeparator(";")]
|
||||
|
||||
# Initialize the BaseChunker with custom separators
|
||||
chunker = BaseChunker(separators=custom_separators)
|
||||
|
||||
# Text with custom separators
|
||||
input_text = "This text, separated by commas; should be split accordingly."
|
||||
|
||||
# Chunk the text
|
||||
chunks = chunker.chunk(input_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}: {chunk.value}")
|
||||
```
|
||||
|
||||
#### Example 3: Adjusting Maximum Tokens
|
||||
|
||||
```python
|
||||
from basechunker import BaseChunker
|
||||
|
||||
# Initialize the BaseChunker with a custom maximum token limit
|
||||
chunker = BaseChunker(max_tokens=50)
|
||||
|
||||
# Long text input
|
||||
input_text = "This is an exceptionally long text that should be broken into smaller chunks based on token count."
|
||||
|
||||
# Chunk the text
|
||||
chunks = chunker.chunk(input_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}: {chunk.value}")
|
||||
```
|
||||
|
||||
### 4.3. Additional Features
|
||||
|
||||
The `BaseChunker` class also provides additional features:
|
||||
|
||||
#### Recursive Chunking
|
||||
The `_chunk_recursively` method handles the recursive chunking of text, ensuring that each chunk stays within the token limit.
|
||||
|
||||
---
|
||||
|
||||
## 5. Additional Information <a name="additional-information"></a>
|
||||
|
||||
- **Text Chunking**: The `BaseChunker` module is a fundamental tool for text chunking, a crucial step in preprocessing text data for various natural language processing tasks.
|
||||
- **Custom Separators**: You can customize the separators used to split the text, allowing flexibility in how text is chunked.
|
||||
- **Token Count**: The module accurately counts tokens using the specified tokenizer, ensuring that chunks do not exceed token limits.
|
||||
|
||||
---
|
||||
|
||||
## 6. Conclusion <a name="conclusion"></a>
|
||||
|
||||
The `BaseChunker` module is an essential tool for text preprocessing and handling long or complex text inputs in natural language processing tasks. This documentation has provided a comprehensive guide on its usage, parameters, and examples, enabling you to efficiently manage and process text data by splitting it into manageable chunks.
|
||||
|
||||
By using the `BaseChunker`, you can ensure that your text data remains within token limits and is ready for further analysis and processing.
|
||||
|
||||
*Please check the official `BaseChunker` repository and documentation for any updates beyond the knowledge cutoff date.*
|
@ -0,0 +1,147 @@
|
||||
# PdfChunker Documentation
|
||||
|
||||
## Table of Contents
|
||||
1. [Introduction](#introduction)
|
||||
2. [Overview](#overview)
|
||||
3. [Installation](#installation)
|
||||
4. [Usage](#usage)
|
||||
1. [PdfChunker Class](#pdfchunker-class)
|
||||
2. [Examples](#examples)
|
||||
5. [Additional Information](#additional-information)
|
||||
6. [Conclusion](#conclusion)
|
||||
|
||||
---
|
||||
|
||||
## 1. Introduction <a name="introduction"></a>
|
||||
|
||||
The `PdfChunker` module is a specialized tool designed to split PDF text content into smaller, more manageable chunks. It is a valuable asset for processing PDF documents in natural language processing and text analysis tasks.
|
||||
|
||||
This documentation provides a comprehensive guide on how to use the `PdfChunker` module. It covers its purpose, parameters, and usage, ensuring that you can effectively process PDF text content.
|
||||
|
||||
---
|
||||
|
||||
## 2. Overview <a name="overview"></a>
|
||||
|
||||
The `PdfChunker` module serves a critical role in handling PDF text content, which is often lengthy and complex. Key features and parameters of the `PdfChunker` module include:
|
||||
|
||||
- `separators`: Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks.
|
||||
- `tokenizer`: Defines the tokenizer used for counting tokens in the text.
|
||||
- `max_tokens`: Sets the maximum token limit for each chunk.
|
||||
|
||||
By using the `PdfChunker`, you can efficiently prepare PDF text content for further analysis and processing.
|
||||
|
||||
---
|
||||
|
||||
## 3. Installation <a name="installation"></a>
|
||||
|
||||
Before using the `PdfChunker` module, ensure you have the required dependencies installed. The module relies on the `swarms` library. You can install this dependency using pip:
|
||||
|
||||
```bash
|
||||
pip install swarms
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Usage <a name="usage"></a>
|
||||
|
||||
In this section, we'll explore how to use the `PdfChunker` module effectively. It consists of the `PdfChunker` class and provides examples to demonstrate its usage.
|
||||
|
||||
### 4.1. `PdfChunker` Class <a name="pdfchunker-class"></a>
|
||||
|
||||
The `PdfChunker` class is the core component of the `PdfChunker` module. It is used to create a `PdfChunker` instance, which can split PDF text content into manageable chunks.
|
||||
|
||||
#### Parameters:
|
||||
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks.
|
||||
- `tokenizer` (OpenAiTokenizer): Defines the tokenizer used for counting tokens in the text.
|
||||
- `max_tokens` (int): Sets the maximum token limit for each chunk.
|
||||
|
||||
### 4.2. Examples <a name="examples"></a>
|
||||
|
||||
Let's explore how to use the `PdfChunker` class with different scenarios and applications.
|
||||
|
||||
#### Example 1: Basic Chunking
|
||||
|
||||
```python
|
||||
from swarms.chunkers.pdf_chunker import PdfChunker
|
||||
from swarms.chunkers.chunk_seperator import ChunkSeparator
|
||||
|
||||
# Initialize the PdfChunker
|
||||
pdf_chunker = PdfChunker()
|
||||
|
||||
# PDF text content to be chunked
|
||||
pdf_text = "This is a PDF document with multiple paragraphs and sentences. It should be split into smaller chunks for analysis."
|
||||
|
||||
# Chunk the PDF text content
|
||||
chunks = pdf_chunker.chunk(pdf_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}:\n{chunk.value}")
|
||||
```
|
||||
|
||||
#### Example 2: Custom Separators
|
||||
|
||||
```python
|
||||
from swarms.chunkers.pdf_chunker import PdfChunker
|
||||
from swarms.chunkers.chunk_seperator import ChunkSeparator
|
||||
|
||||
# Define custom separators for PDF chunking
|
||||
custom_separators = [ChunkSeparator("\n\n"), ChunkSeparator(". ")]
|
||||
|
||||
# Initialize the PdfChunker with custom separators
|
||||
pdf_chunker = PdfChunker(separators=custom_separators)
|
||||
|
||||
# PDF text content with custom separators
|
||||
pdf_text = "This PDF document has custom paragraph separators.\n\nIt also uses period-based sentence separators. Split accordingly."
|
||||
|
||||
# Chunk the PDF text content
|
||||
chunks = pdf_chunker.chunk(pdf_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}:\n{chunk.value}")
|
||||
```
|
||||
|
||||
#### Example 3: Adjusting Maximum Tokens
|
||||
|
||||
```python
|
||||
from swarms.chunkers.pdf_chunker import PdfChunker
|
||||
|
||||
# Initialize the PdfChunker with a custom maximum token limit
|
||||
pdf_chunker = PdfChunker(max_tokens=50)
|
||||
|
||||
# Lengthy PDF text content
|
||||
pdf_text = "This is an exceptionally long PDF document that should be broken into smaller chunks based on token count."
|
||||
|
||||
# Chunk the PDF text content
|
||||
chunks = pdf_chunker.chunk(pdf_text)
|
||||
|
||||
# Print the generated chunks
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
print(f"Chunk {idx}:\n{chunk.value}")
|
||||
```
|
||||
|
||||
### 4.3. Additional Features
|
||||
|
||||
The `PdfChunker` class also provides additional features:
|
||||
|
||||
#### Recursive Chunking
|
||||
The `_chunk_recursively` method handles the recursive chunking of PDF text content, ensuring that each chunk stays within the token limit.
|
||||
|
||||
---
|
||||
|
||||
## 5. Additional Information <a name="additional-information"></a>
|
||||
|
||||
- **PDF Text Chunking**: The `PdfChunker` module is a specialized tool for splitting PDF text content into manageable chunks, making it suitable for natural language processing tasks involving PDF documents.
|
||||
- **Custom Separators**: You can customize separators to adapt the PDF text content chunking process to specific document structures.
|
||||
- **Token Count**: The module accurately counts tokens using the specified tokenizer, ensuring that chunks do not exceed token limits.
|
||||
|
||||
---
|
||||
|
||||
## 6. Conclusion <a name="conclusion"></a>
|
||||
|
||||
The `PdfChunker` module is a valuable asset for processing PDF text content in various natural language processing and text analysis tasks. This documentation has provided a comprehensive guide on its usage, parameters, and examples, ensuring that you can effectively prepare PDF documents for further analysis and processing.
|
||||
|
||||
By using the `PdfChunker`, you can efficiently break down lengthy and complex PDF text content into manageable chunks, making it ready for in-depth analysis.
|
||||
|
||||
*Please check the official `PdfChunker` repository and documentation for any updates beyond the knowledge cutoff date.*
|
@ -1,8 +0,0 @@
|
||||
attrs==21.2.0
|
||||
griptape==0.18.2
|
||||
oceandb==0.1.0
|
||||
pgvector==0.2.3
|
||||
pydantic==1.10.8
|
||||
SQLAlchemy==1.4.49
|
||||
SQLAlchemy==2.0.20
|
||||
swarms==1.8.2
|
@ -0,0 +1,69 @@
|
||||
import pytest
|
||||
from swarms.chunkers.base import BaseChunker, TextArtifact, ChunkSeparator, OpenAiTokenizer # adjust the import paths accordingly
|
||||
|
||||
# 1. Test Initialization
|
||||
def test_chunker_initialization():
|
||||
chunker = BaseChunker()
|
||||
assert isinstance(chunker, BaseChunker)
|
||||
assert chunker.max_tokens == chunker.tokenizer.max_tokens
|
||||
|
||||
def test_default_separators():
|
||||
chunker = BaseChunker()
|
||||
assert chunker.separators == BaseChunker.DEFAULT_SEPARATORS
|
||||
|
||||
def test_default_tokenizer():
|
||||
chunker = BaseChunker()
|
||||
assert isinstance(chunker.tokenizer, OpenAiTokenizer)
|
||||
|
||||
# 2. Test Basic Chunking
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected_output",
|
||||
[
|
||||
("This is a test.", [TextArtifact("This is a test.")]),
|
||||
("Hello World!", [TextArtifact("Hello World!")]),
|
||||
# Add more simple cases
|
||||
],
|
||||
)
|
||||
def test_basic_chunk(input_text, expected_output):
|
||||
chunker = BaseChunker()
|
||||
result = chunker.chunk(input_text)
|
||||
assert result == expected_output
|
||||
|
||||
# 3. Test Chunking with Different Separators
|
||||
def test_custom_separators():
|
||||
custom_separator = ChunkSeparator(";")
|
||||
chunker = BaseChunker(separators=[custom_separator])
|
||||
input_text = "Hello;World!"
|
||||
expected_output = [TextArtifact("Hello;"), TextArtifact("World!")]
|
||||
result = chunker.chunk(input_text)
|
||||
assert result == expected_output
|
||||
|
||||
# 4. Test Recursive Chunking
|
||||
def test_recursive_chunking():
|
||||
chunker = BaseChunker(max_tokens=5)
|
||||
input_text = "This is a more complex text."
|
||||
expected_output = [
|
||||
TextArtifact("This"),
|
||||
TextArtifact("is a"),
|
||||
TextArtifact("more"),
|
||||
TextArtifact("complex"),
|
||||
TextArtifact("text.")
|
||||
]
|
||||
result = chunker.chunk(input_text)
|
||||
assert result == expected_output
|
||||
|
||||
# 5. Test Edge Cases and Special Scenarios
|
||||
def test_empty_text():
|
||||
chunker = BaseChunker()
|
||||
result = chunker.chunk("")
|
||||
assert result == []
|
||||
|
||||
def test_whitespace_text():
|
||||
chunker = BaseChunker()
|
||||
result = chunker.chunk(" ")
|
||||
assert result == [TextArtifact(" ")]
|
||||
|
||||
def test_single_word():
|
||||
chunker = BaseChunker()
|
||||
result = chunker.chunk("Hello")
|
||||
assert result == [TextArtifact("Hello")]
|
@ -0,0 +1,96 @@
|
||||
# tests/test_fuyu.py
|
||||
|
||||
import pytest
|
||||
from swarms.models import Fuyu
|
||||
from transformers import FuyuProcessor, FuyuImageProcessor
|
||||
from PIL import Image
|
||||
|
||||
# Basic test to ensure instantiation of class.
|
||||
def test_fuyu_initialization():
|
||||
fuyu_instance = Fuyu()
|
||||
assert isinstance(fuyu_instance, Fuyu)
|
||||
|
||||
# Using parameterized testing for different init parameters.
|
||||
@pytest.mark.parametrize(
|
||||
"pretrained_path, device_map, max_new_tokens",
|
||||
[
|
||||
("adept/fuyu-8b", "cuda:0", 7),
|
||||
("adept/fuyu-8b", "cpu", 10),
|
||||
],
|
||||
)
|
||||
def test_fuyu_parameters(pretrained_path, device_map, max_new_tokens):
|
||||
fuyu_instance = Fuyu(pretrained_path, device_map, max_new_tokens)
|
||||
assert fuyu_instance.pretrained_path == pretrained_path
|
||||
assert fuyu_instance.device_map == device_map
|
||||
assert fuyu_instance.max_new_tokens == max_new_tokens
|
||||
|
||||
# Fixture for creating a Fuyu instance.
|
||||
@pytest.fixture
|
||||
def fuyu_instance():
|
||||
return Fuyu()
|
||||
|
||||
# Test using the fixture.
|
||||
def test_fuyu_processor_initialization(fuyu_instance):
|
||||
assert isinstance(fuyu_instance.processor, FuyuProcessor)
|
||||
assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor)
|
||||
|
||||
# Test exception when providing an invalid image path.
|
||||
def test_invalid_image_path(fuyu_instance):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
fuyu_instance("Hello", "invalid/path/to/image.png")
|
||||
|
||||
# Using monkeypatch to replace the Image.open method to simulate a failure.
|
||||
def test_image_open_failure(fuyu_instance, monkeypatch):
|
||||
|
||||
def mock_open(*args, **kwargs):
|
||||
raise Exception("Mocked failure")
|
||||
|
||||
monkeypatch.setattr(Image, "open", mock_open)
|
||||
|
||||
with pytest.raises(Exception, match="Mocked failure"):
|
||||
fuyu_instance("Hello", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D")
|
||||
|
||||
# Marking a slow test.
|
||||
@pytest.mark.slow
|
||||
def test_fuyu_model_output(fuyu_instance):
|
||||
# This is a dummy test and may not be functional without real data.
|
||||
output = fuyu_instance("Hello, my name is", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D")
|
||||
assert isinstance(output, str)
|
||||
|
||||
def test_tokenizer_type(fuyu_instance):
|
||||
assert "tokenizer" in dir(fuyu_instance)
|
||||
|
||||
def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
|
||||
assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor
|
||||
assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
|
||||
|
||||
def test_model_device_map(fuyu_instance):
|
||||
assert fuyu_instance.model.device_map == fuyu_instance.device_map
|
||||
|
||||
# Testing maximum tokens setting
|
||||
def test_max_new_tokens_setting(fuyu_instance):
|
||||
assert fuyu_instance.max_new_tokens == 7
|
||||
|
||||
# Test if an exception is raised when invalid text is provided.
|
||||
def test_invalid_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance(None, "path/to/image.png")
|
||||
|
||||
# Test if an exception is raised when empty text is provided.
|
||||
def test_empty_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance("", "path/to/image.png")
|
||||
|
||||
# Test if an exception is raised when a very long text is provided.
|
||||
def test_very_long_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance("A" * 10000, "path/to/image.png")
|
||||
|
||||
# Check model's default device map
|
||||
def test_default_device_map():
|
||||
fuyu_instance = Fuyu()
|
||||
assert fuyu_instance.device_map == "cuda:0"
|
||||
|
||||
# Testing if processor is correctly initialized
|
||||
def test_processor_initialization(fuyu_instance):
|
||||
assert isinstance(fuyu_instance.processor, FuyuProcessor)
|
@ -0,0 +1,99 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import torch
|
||||
from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor
|
||||
|
||||
@pytest.fixture
|
||||
def idefics_instance():
|
||||
with patch("torch.cuda.is_available", return_value=False): # Assuming tests are run on CPU for simplicity
|
||||
instance = Idefics()
|
||||
return instance
|
||||
|
||||
# Basic Tests
|
||||
def test_init_default(idefics_instance):
|
||||
assert idefics_instance.device == "cpu"
|
||||
assert idefics_instance.max_length == 100
|
||||
assert not idefics_instance.chat_history
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device,expected",
|
||||
[
|
||||
(None, "cpu"),
|
||||
("cuda", "cuda"),
|
||||
("cpu", "cpu"),
|
||||
]
|
||||
)
|
||||
def test_init_device(device, expected):
|
||||
with patch("torch.cuda.is_available", return_value=True if expected == "cuda" else False):
|
||||
instance = Idefics(device=device)
|
||||
assert instance.device == expected
|
||||
|
||||
# Test `run` method
|
||||
def test_run(idefics_instance):
|
||||
prompts = [["User: Test"]]
|
||||
with patch.object(idefics_instance, "processor") as mock_processor, \
|
||||
patch.object(idefics_instance, "model") as mock_model:
|
||||
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test"]
|
||||
|
||||
result = idefics_instance.run(prompts)
|
||||
|
||||
assert result == ["Test"]
|
||||
|
||||
# Test `__call__` method (using the same logic as run for simplicity)
|
||||
def test_call(idefics_instance):
|
||||
prompts = [["User: Test"]]
|
||||
with patch.object(idefics_instance, "processor") as mock_processor, \
|
||||
patch.object(idefics_instance, "model") as mock_model:
|
||||
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test"]
|
||||
|
||||
result = idefics_instance(prompts)
|
||||
|
||||
assert result == ["Test"]
|
||||
|
||||
# Test `chat` method
|
||||
def test_chat(idefics_instance):
|
||||
user_input = "User: Hello"
|
||||
response = "Model: Hi there!"
|
||||
with patch.object(idefics_instance, "run", return_value=[response]):
|
||||
result = idefics_instance.chat(user_input)
|
||||
|
||||
assert result == response
|
||||
assert idefics_instance.chat_history == [user_input, response]
|
||||
|
||||
# Test `set_checkpoint` method
|
||||
def test_set_checkpoint(idefics_instance):
|
||||
new_checkpoint = "new_checkpoint"
|
||||
with patch.object(IdeficsForVisionText2Text, "from_pretrained") as mock_from_pretrained, \
|
||||
patch.object(AutoProcessor, "from_pretrained"):
|
||||
idefics_instance.set_checkpoint(new_checkpoint)
|
||||
|
||||
mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16)
|
||||
|
||||
# Test `set_device` method
|
||||
def test_set_device(idefics_instance):
|
||||
new_device = "cuda"
|
||||
with patch.object(idefics_instance.model, "to"):
|
||||
idefics_instance.set_device(new_device)
|
||||
|
||||
assert idefics_instance.device == new_device
|
||||
|
||||
# Test `set_max_length` method
|
||||
def test_set_max_length(idefics_instance):
|
||||
new_length = 150
|
||||
idefics_instance.set_max_length(new_length)
|
||||
assert idefics_instance.max_length == new_length
|
||||
|
||||
# Test `clear_chat_history` method
|
||||
def test_clear_chat_history(idefics_instance):
|
||||
idefics_instance.chat_history = ["User: Test", "Model: Response"]
|
||||
idefics_instance.clear_chat_history()
|
||||
assert not idefics_instance.chat_history
|
||||
|
||||
# Exception Tests
|
||||
def test_run_with_empty_prompts(idefics_instance):
|
||||
with pytest.raises(Exception): # Replace Exception with the actual exception that may arise for an empty prompt.
|
||||
idefics_instance.run([])
|
@ -0,0 +1,74 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from swarms.models.vilt import Vilt, Image, requests
|
||||
|
||||
# Fixture for Vilt instance
|
||||
@pytest.fixture
|
||||
def vilt_instance():
|
||||
return Vilt()
|
||||
|
||||
# 1. Test Initialization
|
||||
def test_vilt_initialization(vilt_instance):
|
||||
assert isinstance(vilt_instance, Vilt)
|
||||
assert vilt_instance.processor is not None
|
||||
assert vilt_instance.model is not None
|
||||
|
||||
# 2. Test Model Predictions
|
||||
@patch.object(requests, 'get')
|
||||
@patch.object(Image, 'open')
|
||||
def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance):
|
||||
mock_image = Mock()
|
||||
mock_image_open.return_value = mock_image
|
||||
mock_requests_get.return_value.raw = Mock()
|
||||
|
||||
# It's a mock response, so no real answer expected
|
||||
with pytest.raises(Exception): # Ensure exception is more specific
|
||||
vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80")
|
||||
|
||||
# 3. Test Exception Handling for network
|
||||
@patch.object(requests, 'get', side_effect=requests.RequestException("Network error"))
|
||||
def test_vilt_network_exception(vilt_instance):
|
||||
with pytest.raises(requests.RequestException):
|
||||
vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80")
|
||||
|
||||
# Parameterized test cases for different inputs
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
("What is this?", "http://example.com/image1.jpg"),
|
||||
("Who is in the image?", "http://example.com/image2.jpg"),
|
||||
("Where was this picture taken?", "http://example.com/image3.jpg"),
|
||||
# ... Add more scenarios
|
||||
]
|
||||
)
|
||||
def test_vilt_various_inputs(text, image_url, vilt_instance):
|
||||
with pytest.raises(Exception): # Again, ensure exception is more specific
|
||||
vilt_instance(text, image_url)
|
||||
|
||||
# Test with invalid or empty text
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
("", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"),
|
||||
(None, "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"),
|
||||
(" ", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"),
|
||||
# ... Add more scenarios
|
||||
]
|
||||
)
|
||||
def test_vilt_invalid_text(text, image_url, vilt_instance):
|
||||
with pytest.raises(ValueError):
|
||||
vilt_instance(text, image_url)
|
||||
|
||||
# Test with invalid or empty image_url
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
("What is this?", ""),
|
||||
("Who is in the image?", None),
|
||||
("Where was this picture taken?", " "),
|
||||
]
|
||||
)
|
||||
def test_vilt_invalid_image_url(text, image_url, vilt_instance):
|
||||
with pytest.raises(ValueError):
|
||||
vilt_instance(text, image_url)
|
||||
|
Loading…
Reference in new issue