swarms/swarms/chunkers/base.py


from __future__ import annotations

from abc import ABC
from typing import Optional

from attr import define, field, Factory
from griptape.artifacts import TextArtifact
from griptape.tokenizers import OpenAiTokenizer

from swarms.chunkers.chunk_seperators import ChunkSeparator


@define
class BaseChunker(ABC):
    """
    Base Chunker.

    A chunker splits a text into smaller chunks that fit within a language
    model's context window, so each chunk can be processed independently.

    Usage:
    --------------
    Subclass ``BaseChunker`` (or use it directly) and call ``chunk()`` on a
    string or ``TextArtifact``; see the sketch at the end of this module.
    """

    DEFAULT_SEPARATORS = [ChunkSeparator(" ")]

    separators: list[ChunkSeparator] = field(
        default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True),
        kw_only=True,
    )
    tokenizer: OpenAiTokenizer = field(
        default=Factory(
            lambda: OpenAiTokenizer(
                model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
            )
        ),
        kw_only=True,
    )
    max_tokens: int = field(
        default=Factory(lambda self: self.tokenizer.max_tokens, takes_self=True),
        kw_only=True,
    )

    def chunk(self, text: TextArtifact | str) -> list[TextArtifact]:
        """Split ``text`` into a list of ``TextArtifact`` chunks."""
        text = text.value if isinstance(text, TextArtifact) else text

        return [TextArtifact(c) for c in self._chunk_recursively(text)]

    def _chunk_recursively(
        self, chunk: str, current_separator: Optional[ChunkSeparator] = None
    ) -> list[str]:
        token_count = self.tokenizer.token_count(chunk)

        # Base case: the chunk already fits within the token budget.
        if token_count <= self.max_tokens:
            return [chunk]
        else:
            balance_index = -1
            balance_diff = float("inf")
            tokens_count = 0
            half_token_count = token_count // 2

            # Only try the current separator and the finer-grained ones after
            # it, so recursion always moves toward smaller split units.
            if current_separator:
                separators = self.separators[self.separators.index(current_separator) :]
            else:
                separators = self.separators

            for separator in separators:
                subchunks = list(filter(None, chunk.split(separator.value)))

                if len(subchunks) > 1:
                    # Find the split point that yields the most evenly balanced
                    # halves, measured in tokens. The separator is re-attached
                    # to each piece so that its token cost is counted.
                    for index, subchunk in enumerate(subchunks):
                        if separator.is_prefix:
                            subchunk = separator.value + subchunk
                        else:
                            subchunk = subchunk + separator.value

                        tokens_count += self.tokenizer.token_count(subchunk)

                        if abs(tokens_count - half_token_count) < balance_diff:
                            balance_index = index
                            balance_diff = abs(tokens_count - half_token_count)

                    # Rejoin the two halves, keeping the separator on the
                    # correct side of each subchunk.
                    if separator.is_prefix:
                        first_subchunk = separator.value + separator.value.join(
                            subchunks[: balance_index + 1]
                        )
                        second_subchunk = separator.value + separator.value.join(
                            subchunks[balance_index + 1 :]
                        )
                    else:
                        first_subchunk = (
                            separator.value.join(subchunks[: balance_index + 1])
                            + separator.value
                        )
                        second_subchunk = separator.value.join(
                            subchunks[balance_index + 1 :]
                        )

                    # Recurse into both halves with the same separator.
                    first_subchunk_rec = self._chunk_recursively(
                        first_subchunk.strip(), separator
                    )
                    second_subchunk_rec = self._chunk_recursively(
                        second_subchunk.strip(), separator
                    )

                    if first_subchunk_rec and second_subchunk_rec:
                        return first_subchunk_rec + second_subchunk_rec
                    elif first_subchunk_rec:
                        return first_subchunk_rec
                    elif second_subchunk_rec:
                        return second_subchunk_rec
                    else:
                        return []

            # No separator produced more than one piece; give up on this chunk.
            return []
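

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a concrete
# chunker only needs to pick its separators; the ``ParagraphChunker`` below is
# hypothetical and is not defined in swarms. Token counting goes through
# griptape's OpenAiTokenizer, which counts tokens locally via tiktoken, so no
# API call should be needed just to chunk text.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ParagraphChunker(BaseChunker):
        # Try paragraph breaks first, then line breaks, then single spaces.
        DEFAULT_SEPARATORS = [
            ChunkSeparator("\n\n"),
            ChunkSeparator("\n"),
            ChunkSeparator(" "),
        ]

    chunker = ParagraphChunker(max_tokens=200)
    artifacts = chunker.chunk("First paragraph...\n\nSecond paragraph...\n\n" * 50)

    for artifact in artifacts:
        # Each result is a TextArtifact whose value fits within max_tokens.
        print(len(artifact.value))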