You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/memory/utils.py

75 lines
2.4 KiB

"""Utility functions for working with vectors and vectorstores."""
from enum import Enum
from typing import List, Tuple, Type
import numpy as np
from swarms.structs.document import Document
from swarms.memory.cosine_similarity import cosine_similarity
class DistanceStrategy(str, Enum):
    """Named strategies for measuring the distance between two vectors.

    The class mixes in ``str`` and every member's value is its own name,
    so a member compares equal to the plain string of the same spelling
    and can be looked up from that string.
    """

    EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
    MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
    DOT_PRODUCT = "DOT_PRODUCT"
    JACCARD = "JACCARD"
    COSINE = "COSINE"
def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Greedily choose embedding indices by maximal marginal relevance.

    The first index chosen is the embedding with the highest cosine
    similarity to the query. Each later round scores every index not yet
    chosen as::

        lambda_mult * sim(query, candidate)
            - (1 - lambda_mult) * max(sim(candidate, already_chosen))

    and takes the highest-scoring one (the earliest index wins ties).

    Args:
        query_embedding: Query vector. A 1-D array is expanded to a
            single-row 2-D array before similarities are computed.
        embedding_list: Candidate embeddings to choose from.
        lambda_mult: Multiplier on the query-similarity term;
            ``1 - lambda_mult`` multiplies the redundancy term.
        k: Upper bound on the number of indices returned.

    Returns:
        Indices into ``embedding_list`` in the order they were chosen,
        ``min(k, len(embedding_list))`` of them; an empty list when that
        bound is zero or negative.
    """
    limit = min(k, len(embedding_list))
    if limit <= 0:
        return []

    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    relevance = cosine_similarity(query_embedding, embedding_list)[0]

    # Seed the selection with the single most query-relevant embedding.
    seed = int(np.argmax(relevance))
    idxs = [seed]
    chosen = np.array([embedding_list[seed]])

    while len(idxs) < limit:
        # Row i holds the similarity of embedding i to each chosen vector.
        overlap = cosine_similarity(embedding_list, chosen)

        # NOTE(review): if no candidate ever compares greater than -inf
        # (e.g. every score is NaN), the -1 sentinel is what gets appended.
        winner = -1
        winner_score = -np.inf
        for candidate, query_score in enumerate(relevance):
            if candidate not in idxs:
                redundancy = max(overlap[candidate])
                mmr_score = (
                    lambda_mult * query_score
                    - (1 - lambda_mult) * redundancy
                )
                if mmr_score > winner_score:
                    winner_score = mmr_score
                    winner = candidate

        idxs.append(winner)
        chosen = np.append(chosen, [embedding_list[winner]], axis=0)

    return idxs
def filter_complex_metadata(
    documents: List[Document],
    *,
    allowed_types: Tuple[Type, ...] = (str, bool, int, float)
) -> List[Document]:
    """Drop metadata entries whose values are not of an allowed type.

    Each document's ``metadata`` attribute is rebound to a new dict that
    keeps only the key/value pairs whose value passes
    ``isinstance(value, allowed_types)``. The documents themselves are
    modified in place; the same objects are returned in a new list.

    Args:
        documents: Documents whose ``metadata`` mappings should be filtered.
        allowed_types: Keyword-only tuple of types a metadata value may be
            an instance of in order to be kept.

    Returns:
        A new list containing the input documents, in their original
        order, each with its filtered metadata.
    """
    cleaned = []
    for doc in documents:
        # Rebind rather than edit the existing dict, so the previous
        # metadata mapping object itself is left untouched.
        doc.metadata = {
            name: item
            for name, item in doc.metadata.items()
            if isinstance(item, allowed_types)
        }
        cleaned.append(doc)
    return cleaned