parent
c9f479983c
commit
c33ceb4327
@ -1,5 +0,0 @@
|
||||
from swarms.chunkers.base_chunker import BaseChunker
|
||||
from swarms.chunkers.chunk_seperator import ChunkSeparator
|
||||
from swarms.chunkers.text_chunker import TextChunker
|
||||
|
||||
__all__ = ["ChunkSeparator", "BaseChunker", "TextChunker"]
|
@ -1,163 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from swarms.artifacts.text_artifact import TextArtifact
|
||||
from swarms.chunkers.chunk_seperator import ChunkSeparator
|
||||
from swarms.tokenizers.base_tokenizer import BaseTokenizer
|
||||
from swarms.tokenizers.openai_tokenizers import OpenAITokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseChunker(ABC):
|
||||
"""
|
||||
Base class for chunking text into smaller chunks.
|
||||
"""
|
||||
|
||||
DEFAULT_SEPARATORS = [ChunkSeparator(" ")]
|
||||
|
||||
separators: list[ChunkSeparator] = field(
|
||||
default_factory=lambda: BaseChunker.DEFAULT_SEPARATORS
|
||||
)
|
||||
tokenizer: BaseTokenizer = field(
|
||||
default_factory=lambda: OpenAITokenizer(
|
||||
model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
|
||||
)
|
||||
)
|
||||
max_tokens: int = field(
|
||||
default_factory=lambda: BaseChunker.tokenizer.max_tokens
|
||||
)
|
||||
|
||||
def chunk(self, text: str | str) -> list[str]:
|
||||
"""
|
||||
Chunk the given text into smaller chunks.
|
||||
|
||||
Args:
|
||||
text (TextArtifact | str): The text to be chunked.
|
||||
|
||||
Returns:
|
||||
list[TextArtifact]: The list of chunked text artifacts.
|
||||
"""
|
||||
text = text.value if isinstance(text, str) else text
|
||||
|
||||
return [
|
||||
TextArtifact(c) for c in self._chunk_recursively(text)
|
||||
]
|
||||
|
||||
def _chunk_recursively(
|
||||
self,
|
||||
chunk: str,
|
||||
current_separator: ChunkSeparator | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Recursively chunk the given chunk into smaller subchunks.
|
||||
|
||||
Args:
|
||||
chunk (str): The chunk to be recursively chunked.
|
||||
current_separator (Optional[ChunkSeparator], optional): The current separator to be used. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of recursively chunked subchunks.
|
||||
"""
|
||||
token_count = self.tokenizer.count_tokens(chunk)
|
||||
|
||||
if token_count <= self.max_tokens:
|
||||
return [chunk]
|
||||
else:
|
||||
balance_index = -1
|
||||
balance_diff = float("inf")
|
||||
tokens_count = 0
|
||||
half_token_count = token_count // 2
|
||||
|
||||
# If a separator is provided, only use separators after it.
|
||||
if current_separator:
|
||||
separators = self.separators[
|
||||
self.separators.index(current_separator) :
|
||||
]
|
||||
else:
|
||||
separators = self.separators
|
||||
|
||||
# Loop through available separators to find the best split.
|
||||
for separator in separators:
|
||||
# Split the chunk into subchunks using the current separator.
|
||||
subchunks = list(
|
||||
filter(None, chunk.split(separator.value))
|
||||
)
|
||||
|
||||
# Check if the split resulted in more than one subchunk.
|
||||
if len(subchunks) > 1:
|
||||
# Iterate through the subchunks and calculate token counts.
|
||||
for index, subchunk in enumerate(subchunks):
|
||||
if index < len(subchunks):
|
||||
if separator.is_prefix:
|
||||
subchunk = separator.value + subchunk
|
||||
else:
|
||||
subchunk = subchunk + separator.value
|
||||
|
||||
tokens_count += self.tokenizer.count_tokens(
|
||||
subchunk
|
||||
)
|
||||
|
||||
# Update the best split if the current one is more balanced.
|
||||
if (
|
||||
abs(tokens_count - half_token_count)
|
||||
< balance_diff
|
||||
):
|
||||
balance_index = index
|
||||
balance_diff = abs(
|
||||
tokens_count - half_token_count
|
||||
)
|
||||
|
||||
# Create the two subchunks based on the best separator.
|
||||
if separator.is_prefix:
|
||||
# If the separator is a prefix, append it before this subchunk.
|
||||
first_subchunk = (
|
||||
separator.value
|
||||
+ separator.value.join(
|
||||
subchunks[: balance_index + 1]
|
||||
)
|
||||
)
|
||||
second_subchunk = (
|
||||
separator.value
|
||||
+ separator.value.join(
|
||||
subchunks[balance_index + 1 :]
|
||||
)
|
||||
)
|
||||
else:
|
||||
# If the separator is not a prefix, append it after this subchunk.
|
||||
first_subchunk = (
|
||||
separator.value.join(
|
||||
subchunks[: balance_index + 1]
|
||||
)
|
||||
+ separator.value
|
||||
)
|
||||
second_subchunk = separator.value.join(
|
||||
subchunks[balance_index + 1 :]
|
||||
)
|
||||
|
||||
# Continue recursively chunking the subchunks.
|
||||
first_subchunk_rec = self._chunk_recursively(
|
||||
first_subchunk.strip(), separator
|
||||
)
|
||||
second_subchunk_rec = self._chunk_recursively(
|
||||
second_subchunk.strip(), separator
|
||||
)
|
||||
|
||||
# Return the concatenated results of the subchunks if both are non-empty.
|
||||
if first_subchunk_rec and second_subchunk_rec:
|
||||
return (
|
||||
first_subchunk_rec + second_subchunk_rec
|
||||
)
|
||||
# If only one subchunk is non-empty, return it.
|
||||
elif first_subchunk_rec:
|
||||
return first_subchunk_rec
|
||||
elif second_subchunk_rec:
|
||||
return second_subchunk_rec
|
||||
else:
|
||||
return []
|
||||
# If none of the separators result in a balanced split, split the chunk in half.
|
||||
midpoint = len(chunk) // 2
|
||||
return self._chunk_recursively(
|
||||
chunk[:midpoint]
|
||||
) + self._chunk_recursively(chunk[midpoint:])
|
@ -1,7 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkSeparator:
|
||||
value: str
|
||||
is_prefix: bool = False
|
@ -1,13 +0,0 @@
|
||||
from swarms.chunkers.base_chunker import BaseChunker
|
||||
from swarms.chunkers.chunk_seperator import ChunkSeparator
|
||||
|
||||
|
||||
class TextChunker(BaseChunker):
|
||||
DEFAULT_SEPARATORS = [
|
||||
ChunkSeparator("\n\n"),
|
||||
ChunkSeparator("\n"),
|
||||
ChunkSeparator(". "),
|
||||
ChunkSeparator("! "),
|
||||
ChunkSeparator("? "),
|
||||
ChunkSeparator(" "),
|
||||
]
|
@ -1,6 +0,0 @@
|
||||
from swarms.schedulers.agent_process import (
|
||||
AgentProcess,
|
||||
AgentProcessQueue,
|
||||
)
|
||||
|
||||
__all__ = ["AgentProcess", "AgentProcessQueue"]
|
Loading…
Reference in new issue