parent
2de1853609
commit
ac57d09bd5
@ -1,3 +1,38 @@
|
||||
#init ocean
|
||||
# TODO upload ocean to pip and config it to the abstract class
|
||||
import logging
|
||||
from typing import Union, List
|
||||
|
||||
import oceandb
|
||||
from oceandb.utils.embedding_function import MultiModalEmbeddingFunction
|
||||
|
||||
class OceanDB:
|
||||
def __init__(self):
|
||||
try:
|
||||
self.client = oceandb.Client()
|
||||
print(self.client.heartbeat())
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize OceanDB client. Error: {e}")
|
||||
|
||||
def create_collection(self, collection_name: str, modality: str):
|
||||
try:
|
||||
embedding_function = MultiModalEmbeddingFunction(modality=modality)
|
||||
collection = self.client.create_collection(collection_name, embedding_function=embedding_function)
|
||||
return collection
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create collection. Error {e}")
|
||||
|
||||
def add_documents(self, collection, documents: List[str], ids: List[str]):
|
||||
try:
|
||||
return collection.add(documents=documents, ids=ids)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to add documents to collection. Error: {e}")
|
||||
raise
|
||||
|
||||
def query(self, collection, query_texts: list[str], n_results: int):
|
||||
try:
|
||||
results = collection.query(query_texts=query_texts, n_results=n_results)
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to query the collection. Error {e}")
|
||||
raise
|
@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
from pegasus import Pegasus
|
||||
|
||||
# import oceandb
|
||||
# from oceandb.utils.embedding_functions import MultiModalEmbeddingfunction
|
||||
|
||||
|
||||
class PegasusEmbedding:
|
||||
def __init__(self, modality: str, multi_process: bool = False, n_processes: int = 4):
|
||||
self.modality = modality
|
||||
self.multi_process = multi_process
|
||||
self.n_processes = n_processes
|
||||
try:
|
||||
self.pegasus = Pegasus(modality, multi_process, n_processes)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize Pegasus with modality: {modality}: {e}")
|
||||
raise
|
||||
|
||||
def embed(self, data: Union[str, list[str]]):
|
||||
try:
|
||||
return self.pegasus.embed(data)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate embeddings. Error: {e}")
|
||||
raise
|
||||
|
Loading…
Reference in new issue