from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from langchain.embeddings.base import Embeddings
from transformers import AutoModel, AutoTokenizer

# Default embedding model, used when no model name is given
DEFAULT_MODEL_NAME = "avsolatorio/NoInstruct-small-Embedding-v0"

class CustomHuggingFaceEmbeddings(Embeddings):
    """
    A custom embeddings class that wraps a Hugging Face model for generating embeddings.

    Supports two modes:
    - "sentence": uses the [CLS] token representation for sentence/document embeddings.
    - "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings.
    """

    def __init__(
        self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence"
    ):
        self.model_name = model_name
        # Use the GPU if available, else fall back to the CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.default_mode = default_mode  # "sentence" or "query"
        self.model.eval()  # Inference only: disables dropout and similar layers

    def get_embedding(
        self, text: Union[str, List[str]], mode: Optional[str] = None
    ) -> torch.Tensor:
        if mode is None:
            mode = self.default_mode
        assert mode in (
            "query",
            "sentence",
        ), f"Unsupported mode: {mode}. Only 'query' and 'sentence' are supported."

        # Ensure we are working with a list of texts
        if isinstance(text, str):
            text = [text]

        # Tokenize the input texts
        inp = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Move the input tensors to the same device as the model
        inp = {key: value.to(self.device) for key, value in inp.items()}

        # Forward pass (no gradients needed)
        with torch.no_grad():
            output = self.model(**inp)

        if mode == "query":
            # Mean pooling: zero out padding positions via the attention mask,
            # then average the remaining token vectors
            vectors = output.last_hidden_state * inp["attention_mask"].unsqueeze(2)
            vectors = vectors.sum(dim=1) / inp["attention_mask"].sum(dim=-1).view(-1, 1)
        else:
            # Sentence/document embedding: use the [CLS] token (first token) representation
            vectors = output.last_hidden_state[:, 0, :]
        return vectors

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of documents (using "sentence" mode).
        """
        vectors = self.get_embedding(texts, mode="sentence")
        return vectors.cpu().numpy().tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Compute an embedding for a single query (using "query" mode).
        """
        vector = self.get_embedding(text, mode="query")
        return vector.cpu().numpy()[0].tolist()

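# Example LangChain integration (a sketch, not part of the original file):
# because this class implements the LangChain ``Embeddings`` interface, it can
# be dropped into a vector store. The snippet below assumes the optional
# faiss-cpu package is installed; FAISS.from_texts and similarity_search are
# standard LangChain APIs.
#
#     from langchain.vectorstores import FAISS
#
#     store = FAISS.from_texts(["some text", "another text"], CustomHuggingFaceEmbeddings())
#     results = store.similarity_search("a question about the texts", k=2)
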
# For quick testing
if __name__ == "__main__":
    embeddings = CustomHuggingFaceEmbeddings()

    # Example texts for document embeddings
    texts = [
        "Illustration of the REaLTabFormer model. The left block shows the non-relational tabular data model using GPT-2.",
        "Predicting human mobility holds significant practical value, with applications in disaster planning and epidemic simulation.",
        "As economies adopt digital technologies, policy makers are asking how to prepare the workforce for emerging labor demands.",
    ]
    doc_embeddings = embeddings.embed_documents(texts)
    print("Document embeddings:", doc_embeddings)

    # Example query embedding
    query_embedding = embeddings.embed_query("Which sentence talks about jobs?")
    print("Query embedding:", query_embedding)
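
    # Illustrative addition (not in the original test): rank the documents
    # against the query by cosine similarity, which also exercises the
    # torch.nn.functional import (F). Raw tensors from get_embedding are used
    # here instead of the list outputs above; F.cosine_similarity broadcasts
    # the (1, D) query vector against the (N, D) document matrix.
    doc_vectors = embeddings.get_embedding(texts, mode="sentence")
    query_vector = embeddings.get_embedding(
        "Which sentence talks about jobs?", mode="query"
    )
    similarities = F.cosine_similarity(query_vector, doc_vectors, dim=-1)
    print("Cosine similarities:", similarities.cpu().numpy().tolist())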