from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional

import numpy as np

from swarms.artifacts.text_artifact import TextArtifact
from swarms.chunkers.base_chunker import BaseChunker
from swarms.chunkers.text_chunker import TextChunker
from swarms.tokenizers.base_tokenizer import BaseTokenizer
from swarms.utils.exponential_backoff import ExponentialBackoffMixin


@dataclass
class BaseEmbeddingModel(
    ExponentialBackoffMixin,
    ABC,
    # SerializableMixin
):
    """Abstract base for embedding models, with retries and long-input chunking.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
    """

    model: Optional[str] = None
    tokenizer: Optional[BaseTokenizer] = None
    chunker: BaseChunker = field(init=False)

    def __post_init__(self) -> None:
        # A chunker can only be built when a tokenizer is available to
        # measure token counts.
        if self.tokenizer:
            self.chunker = TextChunker(tokenizer=self.tokenizer)

    def embed_text_artifact(
        self, artifact: TextArtifact
    ) -> list[float]:
        return self.embed_string(artifact.to_text())
    def embed_string(self, string: str) -> list[float]:
        # Retry with exponential backoff; each `attempt` wraps one try.
        for attempt in self.retrying():
            with attempt:
                if (
                    self.tokenizer
                    and self.tokenizer.count_tokens(string)
                    > self.tokenizer.max_tokens
                ):
                    # Input exceeds the context window: chunk it and merge
                    # the chunk embeddings.
                    return self._embed_long_string(string)
                else:
                    return self.try_embed_chunk(string)
        else:
            # `for`/`else`: reached only if the retry generator is exhausted
            # without returning.
            raise RuntimeError("Failed to embed string.")
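
    # Note: `retrying()` is provided by ExponentialBackoffMixin. A hedged
    # sketch of the contract `embed_string` assumes, in tenacity's style
    # (the actual mixin implementation may differ):
    #
    #     from tenacity import Retrying, stop_after_attempt, wait_exponential
    #
    #     def retrying(self) -> Retrying:
    #         return Retrying(
    #             wait=wait_exponential(),
    #             stop=stop_after_attempt(10),
    #         )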

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]:
        ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go."""
        chunks = self.chunker.chunk(string)

        # Embed each chunk and record its length for weighting below.
        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.try_embed_chunk(chunk.value))
            length_chunks.append(len(chunk))

        # Combine the chunk embeddings into one vector via a weighted
        # average, so longer chunks contribute proportionally more.
        embedding_chunks = np.average(
            embedding_chunks, axis=0, weights=length_chunks
        )

        # Normalize the combined vector to unit (L2) length.
        embedding_chunks = embedding_chunks / np.linalg.norm(
            embedding_chunks
        )
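
        # Worked example (illustrative): chunks of lengths 3 and 1 with
        # embeddings e1 = [1, 0] and e2 = [0, 1] average to
        # (3*e1 + 1*e2) / 4 = [0.75, 0.25], which normalizes to
        # approximately [0.949, 0.316].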

        return embedding_chunks.tolist()
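

# --- Illustrative sketch (not part of the library) ---------------------------
# A minimal concrete subclass showing how `try_embed_chunk` might be
# implemented. `_ToyEmbeddingModel` and its hash-seeded vectors are
# hypothetical stand-ins for a real model call.
import hashlib


@dataclass
class _ToyEmbeddingModel(BaseEmbeddingModel):
    dimensions: int = 8

    def try_embed_chunk(self, chunk: str) -> list[float]:
        # Derive a stable seed from the chunk text, then draw a fixed-size
        # vector. A real subclass would call an embedding API here.
        seed = int.from_bytes(
            hashlib.sha256(chunk.encode()).digest()[:4], "big"
        )
        rng = np.random.default_rng(seed)
        return rng.standard_normal(self.dimensions).tolist()


if __name__ == "__main__":
    # With no tokenizer set, `embed_string` always takes the single-chunk
    # path (and assumes `retrying()` yields tenacity-style attempts).
    toy = _ToyEmbeddingModel(model="toy")
    print(toy.embed_string("hello world"))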