You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.5 KiB
57 lines
1.5 KiB
import logging
|
|
from typing import Union
|
|
|
|
from pegasus import Pegasus
|
|
|
|
|
|
class PegasusEmbedding:
    """Thin wrapper around ``Pegasus`` that logs failures before re-raising.

    The constructor arguments are stored on the instance and passed
    positionally to ``Pegasus``.

    Args:
        modality (str): Modality to use for embedding. The examples below
            use "text", "vision" and "audio".
        multi_process (bool, optional): Whether to use multi-process.
            Defaults to False.
        n_processes (int, optional): Number of processes to use.
            Defaults to 4.

    Usage:
        text
        --------------
        pegasus = PegasusEmbedding(modality="text")
        pegasus.embed("Hello world")

        vision
        --------------
        pegasus = PegasusEmbedding(modality="vision")
        pegasus.embed("https://i.imgur.com/1qZ0K8r.jpeg")

        audio
        --------------
        pegasus = PegasusEmbedding(modality="audio")
        pegasus.embed("https://www2.cs.uic.edu/~i101/SoundFiles/StarWars60.wav")
    """

    def __init__(
        self, modality: str, multi_process: bool = False, n_processes: int = 4
    ):
        """Record the configuration and build the underlying ``Pegasus``.

        Any exception raised while constructing ``Pegasus`` is logged at
        error level and then propagated unchanged.
        """
        # The settings are recorded before the backend is built, so they
        # are already set on the instance if construction fails.
        self.modality = modality
        self.multi_process = multi_process
        self.n_processes = n_processes

        try:
            backend = Pegasus(modality, multi_process, n_processes)
        except Exception as err:
            failure = f"Failed to initialize Pegasus with modality: {modality}: {err}"
            logging.error(failure)
            raise

        self.pegasus = backend

    def embed(self, data: Union[str, list[str]]):
        """Return whatever ``self.pegasus.embed`` produces for ``data``.

        Any exception raised by the backend is logged at error level and
        then propagated unchanged.
        """
        try:
            embeddings = self.pegasus.embed(data)
        except Exception as err:
            failure = f"Failed to generate embeddings. Error: {err}"
            logging.error(failure)
            raise

        return embeddings
|